diff --git a/CHANGE_LOG.md b/CHANGE_LOG.md index 6182c1eca..f439d6223 100644 --- a/CHANGE_LOG.md +++ b/CHANGE_LOG.md @@ -1,3 +1,7 @@ +## v0.4.5 Dec. 2021 +refactory the files in example and adding chunk size searching. + + ### v0.4.4 Dec. 2021 The system is successfully evaluated on a multi-node system. The benchmark scripts are integrated with memory-centric tiling borrowed from DeepSpeed. diff --git a/README.md b/README.md index 7fe2ade2a..982bff236 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ For some detail explanation of the above example, please check the guide [here]( For more examples, please check [here](./examples). -A quick-start benchmark script is [here](./examples/run_bert.sh). It is executed with random generated data, therefore you do not need to prepare the real data. It also demostrated all of the optimization techniques for patricksatr. For more optimization tricks runing the benchmark see [Optimization Options](./doc/optimization_options.md). +A quick-start benchmark script is [here](./examples/run_transformers.sh). It is executed with random generated data, therefore you do not need to prepare the real data. It also demostrated all of the optimization techniques for patricksatr. For more optimization tricks runing the benchmark see [Optimization Options](./doc/optimization_options.md). ### Limitations diff --git a/examples/README.md b/examples/README.md index 0da456092..49b075280 100644 --- a/examples/README.md +++ b/examples/README.md @@ -19,12 +19,12 @@ python huggingface_bert.py ### Use PatrickStar to train large model -`run_bert.sh` and `pretrain_bert_demo.py` is an example to train large PTMs with PatrickStar. You could run different size of model by adding config to`run_bert.sh`. +`run_transformers.sh` and `pretrain_bert_demo.py` is an example to train large PTMs with PatrickStar. You could run different size of model by adding config to`run_transformers.sh`. The following command will run a model with 4B params: ```bash -env MODEL_NAME=GPT2_4B RES_CHECK=0 DIST_PLAN="patrickstar" bash run_bert.sh +env MODEL_NAME=GPT2_4B RES_CHECK=0 DIST_PLAN="patrickstar" bash run_transformers.sh ``` For the available `MODEL_NAME`, please check `pretrain_bert_demo.py`. @@ -32,7 +32,7 @@ For the available `MODEL_NAME`, please check `pretrain_bert_demo.py`. Check the accuracy of PatrickStar with Bert: ```bash -bash RES_CHECK=1 run_bert.sh +bash RES_CHECK=1 run_transformers.sh ``` ### MoE support @@ -44,3 +44,13 @@ python -m torch.distributed.launch --nproc_per_node=4 huggingface_bert_moe.py ``` Note that you need to install [FastMoE](https://github.com/laekov/fastmoe) before running this example. + + +### Search the best chunk size + +Chunk size (CS) is an important hyperparameter for patrickstar. +Although you can set an CS value empirically by run your training task serveral times. We provide an systemic way to find a CS with less memory footprint. Using the following command to search the chunk size. + +``` + env CS_SEARCH=1 bash run_transformers.sh +``` diff --git a/examples/benchmark/run_a100_benchmark_large_model.sh b/examples/benchmark/run_a100_benchmark_large_model.sh index 6870eb9cb..04d67970f 100644 --- a/examples/benchmark/run_a100_benchmark_large_model.sh +++ b/examples/benchmark/run_a100_benchmark_large_model.sh @@ -32,7 +32,7 @@ do echo "****************** Begin ***************************" echo "* benchmarking CS ${CS} BS ${BS} MODEL ${MODEL_NAME} " echo "* CPU_EBD ${CPU_EBD} SP ${SP} ACT_OFFLOAD ${ACT_OFFLOAD} MSC ${MSC} CACHE ${CACHE}" -bash ../run_bert.sh +bash ../run_transformers.sh echo "****************** Finished ***************************" echo "" echo "" diff --git a/examples/benchmark/run_a100_benchmark_small_model.sh b/examples/benchmark/run_a100_benchmark_small_model.sh index 01611d99a..a9f03342c 100644 --- a/examples/benchmark/run_a100_benchmark_small_model.sh +++ b/examples/benchmark/run_a100_benchmark_small_model.sh @@ -32,7 +32,7 @@ do echo "****************** Begin ***************************" echo "* benchmarking CS ${CS} BS ${BS} MODEL ${MODEL_NAME} " echo "* CPU_EBD ${CPU_EBD} SP ${SP} ACT_OFFLOAD ${ACT_OFFLOAD} MSC ${MSC} CACHE ${CACHE}" -bash ../run_bert.sh +bash ../run_transformers.sh echo "****************** Finished ***************************" echo "" echo "" diff --git a/examples/benchmark/run_benchmark.sh b/examples/benchmark/run_benchmark.sh index e98acc1a9..f205ff3c7 100644 --- a/examples/benchmark/run_benchmark.sh +++ b/examples/benchmark/run_benchmark.sh @@ -24,7 +24,7 @@ do echo "****************** Begin ***************************" echo "* benchmarking CS ${CS} BS ${BS} MODEL ${MODEL_NAME} " echo "* CPU_EBD ${CPU_EBD} SP ${SP} ACT_OFFLOAD ${ACT_OFFLOAD}" -bash ../run_bert.sh +bash ../run_transformers.sh echo "****************** Finished ***************************" echo "" echo "" diff --git a/examples/eval_chunk_size.py b/examples/eval_chunk_size.py new file mode 100644 index 000000000..9b74da0bb --- /dev/null +++ b/examples/eval_chunk_size.py @@ -0,0 +1,194 @@ +# BSD 3-Clause License +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the psutil authors nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import logging +import torch + +from patrickstar.utils.logging import logger, log_dist +from model_builder import build_transformer_model +from ps_config import get_patrickstar_config +from parse_args import parse_args +from patrickstar.core import PatrickStarClient +from patrickstar.core import PSPreProcessCtx +import time +from patrickstar.utils.distributed import get_rank + +MB_NUM = 1024 * 1024 +GB_NUM = 1024 * MB_NUM + +HARDWARE_SETTING_JSON = { + "per_cpu_mem": 240 * GB_NUM, + "per_gpu_mem": 32 * GB_NUM, + "global_gpu_num": 1, + "gloabl_cpu_num": 1, + "local_gpu_num": 1, + "local_cpu_num": 1, +} + + +def chunk_schema_valid_check(args, config, chunk_size, overall_chunk_size): + """ + check validation of a chunk schema, given the overall chunk size + args: + @args: cmd args + @config: client config + @chunk_size: the chunk size in numel + @overall_chunk_size: the overall chunk size used for param fp16 + returns: + bool: is the chunk schema valid + """ + per_gpu_mem = HARDWARE_SETTING_JSON.get("per_gpu_mem") + per_cpu_mem = HARDWARE_SETTING_JSON.get("per_cpu_mem") + global_gpu_num = HARDWARE_SETTING_JSON.get("global_gpu_num") + global_cpu_num = HARDWARE_SETTING_JSON.get("gloabl_cpu_num") + ava_per_gpu_mem = ( + per_gpu_mem + * config.get("overall_gpu_mem_ratio", 0.8) + * config.get("warmup_gpu_chunk_mem_ratio", 0.1) + ) + + ava_per_cpu_mem = per_cpu_mem * config.get("overall_cpu_mem_ratio", 0.8) + + # GPU mem has to host at least two chunks. + if ava_per_gpu_mem < chunk_size * 2: + logger.error( + "chunk is unable to be fitted in GPU during warmup!\n" + "GPU Mem %.2f MB vs. Two Chunks %.2f MB", + ava_per_gpu_mem / MB_NUM, + chunk_size * 2 / MB_NUM, + ) + return False + + # CPU + GPU shall not exceed the 14M (M numel of param) + overall_cpu_gpu_mem = ( + ava_per_gpu_mem * global_gpu_num + ava_per_cpu_mem * global_cpu_num + ) + need_mem = overall_chunk_size / 6 * 14 + if overall_cpu_gpu_mem < need_mem: + logger.error( + "Overall chunks can't fit in memory of CPU+GPU " "%.2f MB vs. %.2f MB", + overall_cpu_gpu_mem / MB_NUM, + need_mem / MB_NUM, + ) + return False + + logger.info( + "Evaluated chunk size %d Melem" + "ava_per_gpu_mem %.2f MB, " + "ava_per_cpu_mem %.2f MB, " + "need_mem %.2f MB\n", + args.default_chunk_size / MB_NUM, + ava_per_gpu_mem / MB_NUM, + ava_per_cpu_mem / MB_NUM, + need_mem / MB_NUM, + ) + return True + + +def get_param_used_chunk_size(args, config, model_func): + """ + return overall chunk size of param fp16 and param fp32. + as well as the memory utilization of chunks. + """ + client = PatrickStarClient( + rank=args.local_rank, + default_chunk_size=args.default_chunk_size, + config=config.get("client", None), + ) + start_time = time.time() + try: + with PSPreProcessCtx( + client=client, + dtype=torch.float, + release_after_init=args.release_after_init, + use_cpu_embedding=args.use_cpu_embedding, + not_init=True, + ): + model = model_func() + except Exception: + logger.error("PSPreProcessCtx failed") + return -1, -1 + end_time = time.time() + log_dist(f"PSPreProcessCtx Model Constructing elapse {end_time - start_time}") + del model + + overall_chunk_size, util = client.get_overall_chunk_size() + if chunk_schema_valid_check( + args, + config["client"]["mem_tracer"], + args.default_chunk_size, + overall_chunk_size, + ): + + return overall_chunk_size, util + else: + logger.error("Chunk schema validation check failed!") + return -1, -1 + + +def evaluate_chunk_size(args): + """ + Evaluate the current training task defined by the args. + write the chunk memory usage to the file. + """ + # Avoid gpu0 use more memory. + # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113 + torch.cuda.set_device(args.local_rank) + torch.cuda.empty_cache() + + lr = 0.001 + betas = (0.9, 0.999) + eps = 1e-6 + weight_decay = 0 + + model_func, sequence_length = build_transformer_model(args) + config = get_patrickstar_config( + args, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay + ) + + overall_chunk_size, utils = get_param_used_chunk_size(args, config, model_func) + + logger.info( + "chunk uses %.2f MB, utilization %.2f \n", overall_chunk_size / MB_NUM, utils + ) + logger.info(f"writing to {args.slog_file}\n") + + if get_rank() == 0: + with open(f"{args.slog_file}", "a+") as fh: + fh.write( + f"{args.default_chunk_size/1024/1024} {overall_chunk_size/1024/1024}, {utils}\n" + ) + + +if __name__ == "__main__": + args = parse_args() + logger.setLevel(logging.INFO) + torch.manual_seed(0) + evaluate_chunk_size(args=args) diff --git a/examples/model_builder.py b/examples/model_builder.py new file mode 100644 index 000000000..59c856b5a --- /dev/null +++ b/examples/model_builder.py @@ -0,0 +1,234 @@ +# BSD 3-Clause License +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the psutil authors nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import transformers +from transformers import BertConfig +from packaging import version +import optimizations.global_opt_flags as global_opt_flags + + +def model_config(model_name): + """ + generate the model config according to the model name. + """ + if model_name == "Bert": + # 0.11B + HIDDEN_DIM = 768 + SEQ_LEN = 512 + NUM_LAYER = 6 + NUM_HEAD = 12 + elif model_name == "Bertlarge": + # 0.35B + HIDDEN_DIM = 1024 + SEQ_LEN = 512 + NUM_LAYER = 24 + NUM_HEAD = 16 + elif model_name == "GPT2small": + # 0.7B + HIDDEN_DIM = 1536 + SEQ_LEN = 128 + NUM_LAYER = 24 + NUM_HEAD = 16 + elif model_name == "GPT2_1B": + # 0.9B + HIDDEN_DIM = 2048 + SEQ_LEN = 1024 + NUM_LAYER = 20 + NUM_HEAD = 16 + elif model_name == "megatron_1.3B": + HIDDEN_DIM = 2048 + SEQ_LEN = 1024 + NUM_LAYER = 24 + NUM_HEAD = 32 + elif model_name == "GPT2_2B": + # zero-offload + HIDDEN_DIM = 2048 + SEQ_LEN = 1024 + NUM_LAYER = 40 + NUM_HEAD = 16 + elif model_name == "megatron_3.9B": + # Table 4 in Megatron Paper + HIDDEN_DIM = 2560 + SEQ_LEN = 1024 + NUM_LAYER = 24 + NUM_HEAD = 40 + elif model_name == "GPT2_4B": + HIDDEN_DIM = 2304 # 2048 + SEQ_LEN = 1024 + NUM_LAYER = 64 + NUM_HEAD = 16 + elif model_name == "GPT3_6B": + # 6.7B model + HIDDEN_DIM = 3072 + SEQ_LEN = 1024 + NUM_LAYER = 53 + NUM_HEAD = 16 + elif model_name == "GPT3_8B": + # 6.7B model + HIDDEN_DIM = 3072 + SEQ_LEN = 1024 + NUM_LAYER = 72 + NUM_HEAD = 16 + elif model_name == "GPT3_10B": + HIDDEN_DIM = 4096 + SEQ_LEN = 1024 + NUM_LAYER = 50 + NUM_HEAD = 16 + elif model_name == "GPT3_11B": + HIDDEN_DIM = 4096 + SEQ_LEN = 1024 + NUM_LAYER = 55 + NUM_HEAD = 16 + elif model_name == "GPT3_12B": + HIDDEN_DIM = 4096 + SEQ_LEN = 1024 + NUM_LAYER = 60 + NUM_HEAD = 16 + elif model_name == "GPT3_13B": + HIDDEN_DIM = 4096 + SEQ_LEN = 1024 + NUM_LAYER = 65 + NUM_HEAD = 16 + elif model_name == "GPT3_15B": + HIDDEN_DIM = 4096 + SEQ_LEN = 1024 + NUM_LAYER = 78 + NUM_HEAD = 16 + elif model_name == "GPT3_18B": + HIDDEN_DIM = 4096 + SEQ_LEN = 1024 + NUM_LAYER = 90 + NUM_HEAD = 16 + # The following configs comes from paper + # Efficient Large-Scale Language Model Training on GPU Clusters + # NV model is wider in hidden-size + elif model_name == "GPT_NV_18B": + HIDDEN_DIM = 6144 + SEQ_LEN = 1024 + NUM_LAYER = 40 + NUM_HEAD = 16 + elif model_name == "GPT_NV_39B": + HIDDEN_DIM = 8192 + SEQ_LEN = 1024 + NUM_LAYER = 48 + NUM_HEAD = 16 + elif model_name == "GPT_NV_76B": + HIDDEN_DIM = 10240 + SEQ_LEN = 1024 + NUM_LAYER = 60 + NUM_HEAD = 16 + # The following configs comes from Deep-Offload + # http://pasalabs.org/papers/2021/ATC21_zero-offload.pdf + elif model_name == "GPT_DS_20B": + HIDDEN_DIM = 8192 + SEQ_LEN = 1024 + NUM_LAYER = 25 + NUM_HEAD = 16 + elif model_name == "GPT_DS_40B": + HIDDEN_DIM = 8192 + SEQ_LEN = 1024 + NUM_LAYER = 50 + NUM_HEAD = 16 + elif model_name == "GPT_DS_50B": + HIDDEN_DIM = 8192 + SEQ_LEN = 1024 + NUM_LAYER = 62 + NUM_HEAD = 16 + elif model_name == "GPT_DS_60B": + HIDDEN_DIM = 8192 + SEQ_LEN = 1024 + NUM_LAYER = 75 + NUM_HEAD = 16 + elif model_name == "GPT_DS_70B": + HIDDEN_DIM = 9216 + SEQ_LEN = 1024 + NUM_LAYER = 69 + NUM_HEAD = 16 + else: + raise RuntimeError(f"The model name {model_name} is not valid!") + assert HIDDEN_DIM % NUM_HEAD == 0 + return (HIDDEN_DIM, SEQ_LEN, NUM_LAYER, NUM_HEAD) + + +def print_model_config(args, hidden_dim, sequence_len, num_layer, num_head): + if args.rank == 0: + config_dict = { + "hidden_dim": hidden_dim, + "sequence_len": sequence_len, + "num_layer": num_layer, + "num_head": num_head, + } + print("------------------ model config ------------------", flush=True) + str_list = [] + for key, value in config_dict.items(): + dots = "." * (32 - len(key)) + str_list.append(" {} {} {}".format(key, dots, value)) + for arg in sorted(str_list, key=lambda x: x.lower()): + print(arg, flush=True) + print("-------------- end of model config --------------", flush=True) + + +def build_transformer_model(args): + """ + Build a transformer-based model based on transformer bert. + return a function able to build the model. + """ + if args.with_tiling_linear or args.with_activation_offload: + if args.with_tiling_linear: + global_opt_flags.USE_TILE = True + else: + global_opt_flags.USE_TILE = False + if args.with_activation_offload: + global_opt_flags.USE_ACT_OFFLOAD = True + else: + global_opt_flags.USE_ACT_OFFLOAD = False + from optimizations.ps_tile_modeling_bert import BertForSequenceClassification + else: + from transformers import BertForSequenceClassification + + hidden_dim, sequence_length, num_layer, num_head = model_config(args.model_name) + + bert_config = BertConfig( + gradient_checkpointing=args.use_ckp, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + ) + + def model_func(): + model = BertForSequenceClassification(bert_config) + if args.use_ckp and version.parse(transformers.__version__) >= version.parse( + "4.11.0" + ): + model.gradient_checkpointing_enable() + return model + + return model_func, sequence_length diff --git a/examples/optimizations/ps_tile_modeling_bert.py b/examples/optimizations/ps_tile_modeling_bert.py index 2cd9fa600..dbf22da99 100644 --- a/examples/optimizations/ps_tile_modeling_bert.py +++ b/examples/optimizations/ps_tile_modeling_bert.py @@ -584,7 +584,7 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = BertPooler(config) if add_pooling_layer else None - self.init_weights() + # self.init_weights() def get_input_embeddings(self): return self.embeddings.word_embeddings @@ -771,7 +771,7 @@ def __init__(self, config): self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.init_weights() + # self.init_weights() def forward( self, diff --git a/examples/parse_args.py b/examples/parse_args.py new file mode 100644 index 000000000..7921588c7 --- /dev/null +++ b/examples/parse_args.py @@ -0,0 +1,196 @@ +# BSD 3-Clause License +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the psutil authors nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import os + + +def _add_patrick_star_args(parser): + group = parser.add_argument_group(title="patrickstar") + group.add_argument( + "--use_fake_dist", + dest="use_fake_dist", + action="store_true", + help="Using one GPU to stimulate multiple card.", + ) + group.add_argument( + "--default_chunk_size", + type=int, + default=32 * 1024 * 1024, + help="Default Chunk Size in elements.", + ) + group.add_argument( + "--use_cpu_embedding", + dest="use_cpu_embedding", + action="store_true", + help="Using CPU to perform Embedding and do not assign " + "embedding params to chunks", + ) + group.add_argument( + "--release_after_init", + action="store_true", + help="Release the remote chunk after the whole initialization." + "This would use more CPU memory during initialization, " + "but may fix some errors relate to checkpoint loading or" + "weight intialization.", + ) + group.add_argument( + "--use_hybrid_adam", + action="store_true", + help="Use hybrid adam optimization. " + "By default ADAM is on CPU and run ADAM on GPU if possible.", + ) + # Some hyperparams to tune when you failed to run a model. + group.add_argument( + "--with_static_partition", + action="store_true", + help="Use static partition for model data on CPU and GPU.", + ) + group.add_argument( + "--with_mem_profiler", + action="store_true", + help="Profiling memory usage.", + ) + group.add_argument( + "--init_loss_scale_power", + type=float, + default=10, + help="initial loss scale power", + ) + group.add_argument( + "--with_async_mem_monitor", + action="store_true", + help="Use async memory monitor.", + ) + group.add_argument( + "--with_mem_saving_comm", + action="store_true", + help="Use communication saving memory.", + ) + group.add_argument( + "--with_mem_cache", + action="store_true", + help="Use caching to allocate chunk payload.", + ) + group.add_argument( + "--with_async_move", + action="store_true", + help="Use asynchronize move.", + ) + group.add_argument( + "--slog_file", + type=str, + default="./slog_file/tmp.txt", + help="The file to record chunk size serach log.", + ) + return parser + + +def _add_general_opt_args(parser): + group = parser.add_argument_group(title="test_bert") + group.add_argument( + "--use_ckp", + dest="use_ckp", + action="store_true", + help="using gradient checkpointing for memory saveing.", + ) + group.add_argument( + "--with_activation_offload", + dest="with_activation_offload", + action="store_true", + help="Use activation offloading.", + ) + group.add_argument( + "--with_tiling_linear", + action="store_true", + help="Use linear tiling.", + ) + return parser + + +def _add_test_config_args(parser): + group = parser.add_argument_group(title="test_config") + group.add_argument( + "--batch_size", type=int, default=32, help="Batch size of input." + ) + group.add_argument( + "--local_rank", + type=int, + default=None, + help="local rank passed from distributed launcher.", + ) + group.add_argument( + "--res_check", + dest="res_check", + action="store_true", + help="check results correctness of checkpointing.", + ) + group.add_argument( + "--use_fp16", + dest="use_fp16", + action="store_true", + help="using FP16 for training.", + ) + group.add_argument( + "--dist_plan", + type=str, + default="torch", + help="Distributed Plan [torch, patrickstar]", + ) + group.add_argument( + "--model_name", type=str, default="GPTsmall", help="The model name." + ) + group.add_argument("--with_lightseq", action="store_true", help="use lightseq") + return parser + + +def _print_args(args): + """Print arguments.""" + if args.rank == 0: + print("------------------- arguments -------------------", flush=True) + str_list = [] + for arg in vars(args): + dots = "." * (32 - len(arg)) + str_list.append(" {} {} {}".format(arg, dots, getattr(args, arg))) + for arg in sorted(str_list, key=lambda x: x.lower()): + print(arg, flush=True) + print("---------------- end of arguments ----------------", flush=True) + + +def parse_args(): + """Parse all arguments.""" + parser = argparse.ArgumentParser(description="PatrickStar Arguments") + parser = _add_patrick_star_args(parser) + parser = _add_test_config_args(parser) + parser = _add_general_opt_args(parser) + args = parser.parse_args() + args.rank = int(os.getenv("RANK", "0")) + args.world_size = int(os.getenv("WORLD_SIZE", "1")) + _print_args(args) + return args diff --git a/examples/pretrain_bert_demo.py b/examples/pretrain_bert_demo.py index e5e3fcbc5..ca18a933c 100644 --- a/examples/pretrain_bert_demo.py +++ b/examples/pretrain_bert_demo.py @@ -28,355 +28,74 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import argparse import logging -import os import time -from packaging import version import torch import numpy as np -import transformers -from transformers import BertConfig import patrickstar.utils.global_timer as global_timer from data_loader import get_bert_data_loader from patrickstar.profiler import profiler from patrickstar.runtime import initialize_engine from patrickstar.utils import see_memory_usage -from patrickstar.utils.logging import logger -from patrickstar.utils.model_size_calculator import get_ps_model_size, estimate_bert_mac -import optimizations.global_opt_flags as global_opt_flags - - -def _add_patrick_star_args(parser): - group = parser.add_argument_group(title="patrickstar") - group.add_argument( - "--use_fake_dist", - dest="use_fake_dist", - action="store_true", - help="Using one GPU to stimulate multiple card.", - ) - group.add_argument( - "--default_chunk_size", - type=int, - default=32 * 1024 * 1024, - help="Default Chunk Size in elements.", - ) - group.add_argument( - "--use_cpu_embedding", - dest="use_cpu_embedding", - action="store_true", - help="Using CPU to perform Embedding and do not assign " - "embedding params to chunks", - ) - group.add_argument( - "--release_after_init", - action="store_true", - help="Release the remote chunk after the whole initialization." - "This would use more CPU memory during initialization, " - "but may fix some errors relate to checkpoint loading or" - "weight intialization.", - ) - group.add_argument( - "--use_hybrid_adam", - action="store_true", - help="Use hybrid adam optimization. " - "By default ADAM is on CPU and run ADAM on GPU if possible.", - ) - # Some hyperparams to tune when you failed to run a model. - group.add_argument( - "--with_static_partition", - action="store_true", - help="Use static partition for model data on CPU and GPU.", - ) - group.add_argument( - "--with_mem_profiler", - action="store_true", - help="Profiling memory usage.", - ) - group.add_argument( - "--init_loss_scale_power", - type=float, - default=10, - help="initial loss scale power", - ) - group.add_argument( - "--with_async_mem_monitor", - action="store_true", - help="Use async memory monitor.", - ) - group.add_argument( - "--with_mem_saving_comm", - action="store_true", - help="Use communication saving memory.", - ) - group.add_argument( - "--with_mem_cache", - action="store_true", - help="Use caching to allocate chunk payload.", - ) - group.add_argument( - "--with_async_move", - action="store_true", - help="Use asynchronize move.", - ) - return parser +from patrickstar.utils.logging import log_dist, logger +from patrickstar.utils.model_size_calculator import get_ps_model_size +from model_builder import build_transformer_model +from parse_args import parse_args +from ps_config import get_patrickstar_config -def _add_general_opt_args(parser): - group = parser.add_argument_group(title="test_bert") - group.add_argument( - "--use_ckp", - dest="use_ckp", - action="store_true", - help="using gradient checkpointing for memory saveing.", - ) - group.add_argument( - "--with_activation_offload", - dest="with_activation_offload", - action="store_true", - help="Use activation offloading.", - ) - group.add_argument( - "--with_tiling_linear", - action="store_true", - help="Use linear tiling.", - ) - return parser - - -def _add_test_config_args(parser): - group = parser.add_argument_group(title="test_config") - group.add_argument( - "--batch_size", type=int, default=32, help="Batch size of input." - ) - group.add_argument( - "--local_rank", - type=int, - default=None, - help="local rank passed from distributed launcher.", - ) - - group.add_argument( - "--res_check", - dest="res_check", - action="store_true", - help="check results correctness of checkpointing.", - ) - group.add_argument( - "--use_fp16", - dest="use_fp16", - action="store_true", - help="using FP16 for training.", - ) - group.add_argument( - "--dist_plan", - type=str, - default="torch", - help="Distributed Plan [torch, patrickstar]", - ) - group.add_argument( - "--model_name", type=str, default="GPTsmall", help="The model name." - ) - group.add_argument("--with_lightseq", action="store_true", help="use lightseq") - return parser - - -def _print_args(args): - """Print arguments.""" - if args.rank == 0: - print("------------------- arguments -------------------", flush=True) - str_list = [] - for arg in vars(args): - dots = "." * (32 - len(arg)) - str_list.append(" {} {} {}".format(arg, dots, getattr(args, arg))) - for arg in sorted(str_list, key=lambda x: x.lower()): - print(arg, flush=True) - print("---------------- end of arguments ----------------", flush=True) - - -def parse_args(): - """Parse all arguments.""" - parser = argparse.ArgumentParser(description="PatrickStar Arguments") - parser = _add_patrick_star_args(parser) - parser = _add_test_config_args(parser) - parser = _add_general_opt_args(parser) - args = parser.parse_args() - args.rank = int(os.getenv("RANK", "0")) - args.world_size = int(os.getenv("WORLD_SIZE", "1")) - _print_args(args) - return args - - -def print_model_config(hidden_dim, sequence_len, num_layer, num_head): - if args.rank == 0: - config_dict = { - "hidden_dim": hidden_dim, - "sequence_len": sequence_len, - "num_layer": num_layer, - "num_head": num_head, - } - print("------------------ model config ------------------", flush=True) - str_list = [] - for key, value in config_dict.items(): - dots = "." * (32 - len(key)) - str_list.append(" {} {} {}".format(key, dots, value)) - for arg in sorted(str_list, key=lambda x: x.lower()): - print(arg, flush=True) - print("-------------- end of model config --------------", flush=True) - - -def test_bert_model_helper( +def test_transformer_model_helper( args, is_ckp: bool = False, is_fp16: bool = False, dist_plan: str = "torch", - batch_size=32, - hidden_dim=768, - sequence_length=256, - num_layer=12, - num_head=12, num_steps=5, ): - logger.info( - f'test a bert {"fp16" if is_fp16 else "fp32"} model ' - f'{"with checkpoint" if is_ckp else ""}' - ) - - # Use single card to imitate multicard. + # Use single card to simulate multicard. Used when you are poor and + # no more GPU avaiable. if args.use_fake_dist: rank = 0 else: rank = args.local_rank - if args.with_tiling_linear or args.with_activation_offload: - if args.with_tiling_linear: - global_opt_flags.USE_TILE = True - else: - global_opt_flags.USE_TILE = False - if args.with_activation_offload: - global_opt_flags.USE_ACT_OFFLOAD = True - else: - global_opt_flags.USE_ACT_OFFLOAD = False - from optimizations.ps_tile_modeling_bert import BertForSequenceClassification - else: - from transformers import BertForSequenceClassification - # Avoid gpu0 use more memory. # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113 torch.cuda.set_device(rank) torch.cuda.empty_cache() - device = torch.device(f"cuda:{rank}") - bert_config = BertConfig( - gradient_checkpointing=is_ckp, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - ) + if args.with_mem_profiler: + print("start memory profiler") + profiler.start() lr = 0.001 betas = (0.9, 0.999) eps = 1e-6 weight_decay = 0 - if args.with_mem_profiler: - print("start memory profiler") - profiler.start() + model_func, sequence_length = build_transformer_model(args) if dist_plan == "patrickstar": if not is_fp16: logger.warning("PatrickStar will always use mixed precision training.") - - def model_func(): - model = BertForSequenceClassification(bert_config) - if is_ckp and version.parse(transformers.__version__) >= version.parse( - "4.11.0" - ): - model.gradient_checkpointing_enable() - return model - - config = { - # The same format as optimizer config of DeepSpeed - # https://www.deepspeed.ai/docs/config-json/#optimizer-parameters - "optimizer": { - "type": "Adam", - "params": { - "lr": lr, - "betas": betas, - "eps": eps, - "weight_decay": weight_decay, - "use_hybrid_adam": args.use_hybrid_adam, - }, - }, - "fp16": { - "enabled": True, - # Set "loss_scale" to 0 to use DynamicLossScaler. - "loss_scale": 0, - "initial_scale_power": args.init_loss_scale_power, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1, - }, - "default_chunk_size": args.default_chunk_size, - "release_after_init": args.release_after_init, - "use_fake_dist": args.use_fake_dist, - "use_cpu_embedding": args.use_cpu_embedding, - "client": { - "mem_tracer": { - "use_async_mem_monitor": args.with_async_mem_monitor, - "warmup_gpu_chunk_mem_ratio": 0.1, - "overall_gpu_mem_ratio": 0.8, - "overall_cpu_mem_ratio": 0.8, - "margin_use_ratio": 0.8, - "use_fake_dist": False, - "with_static_partition": args.with_static_partition, - }, - "opts": { - "with_mem_saving_comm": args.with_mem_saving_comm, - "with_mem_cache": args.with_mem_cache, - "with_async_move": args.with_async_move, - }, - }, - } - + config = get_patrickstar_config( + args, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay + ) model, optimizer = initialize_engine( model_func=model_func, local_rank=rank, config=config ) else: - model = BertForSequenceClassification(bert_config) + model = model_func() if args.with_mem_profiler: from patrickstar.core.torch_profiler_hook import ( register_torch_profiler_hook, ) register_torch_profiler_hook(model) - if is_ckp and version.parse(transformers.__version__) >= version.parse( - "4.11.0" - ): - model.gradient_checkpointing_enable() + model.cuda(rank) model.train() - if args.with_lightseq: - from optimizations.ls_hf_transformer_encoder_layer import ( - inject_ls_enc_layer, - ) - - inject_ls_enc_layer(model, args, bert_config) - print("Using Lightseq Kernels, all submodules includes:") - - def visit_and_register_hooks(module): - is_child_node = True - for _, submodule in module.named_children(): - visit_and_register_hooks(submodule) - is_child_node = False - if is_child_node: - print(f"module name {module.__class__.__name__}") - - visit_and_register_hooks(model) optimizer = torch.optim.Adam( model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay @@ -393,21 +112,18 @@ def visit_and_register_hooks(module): model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) model_numel, model_num_param = get_ps_model_size(model) - logger.info(f"Model size {model_numel / 1e9} B, total params: {model_num_param}") - total_macs, nvidia_total_macs = estimate_bert_mac( - bert_config, batch_size, sequence_length, model_numel - ) - logger.info(f"Total MACs: {total_macs} TFlops") - logger.info(f"NVIDIA total MACs: {nvidia_total_macs}") - logger.debug(f"Diff csig/nvidia {total_macs / nvidia_total_macs}") + log_dist(f"Model size {model_numel / 1e9} B, total params: {model_num_param}") + total_macs = model_numel * args.batch_size * sequence_length * 2 * 4 + log_dist(f"Total MACs: {total_macs/1024/1024/1024/1024} TFlops") see_memory_usage( f"After model init. using {dist_plan}, gradient checkpoint: {is_ckp}, fp16 {is_fp16}", force=True, ) + # load data, here we generate random data for benchmarking. data_loader = get_bert_data_loader( - batch_size=batch_size, + batch_size=args.batch_size, total_samples=10000, sequence_length=sequence_length, device=device, @@ -417,14 +133,14 @@ def visit_and_register_hooks(module): loss_res = [] - print(f"MAC {total_macs / 1e9} GFlop, model param size: {model_numel / 1e9} B") + print(f"model param size: {model_numel / 1e9} B") for n, batch in enumerate(data_loader): if n == num_steps: break # You may need to empty_cache for really large models. torch.cuda.empty_cache() - logger.info(f"Start Step {n} with {dist_plan}...") + log_dist(f"Start Step {n} with {dist_plan}...") step_start_time = time.time() # Only collect running time of the last iteration. @@ -480,16 +196,15 @@ def visit_and_register_hooks(module): f"Step {n} elaspe {step_elapse} s, {total_macs / 1e12 / step_elapse} Tflops" ) - logger.info(f"End Step {n} with {dist_plan}.\n") + log_dist(f"End Step {n} with {dist_plan}.\n") if args.with_mem_profiler: profiler.end() if rank == 0: profiler.save( - f"{dist_plan}_{args.model_name}_bs_{batch_size}_" + f"{dist_plan}_{args.model_name}_bs_{args.batch_size}_" f"ckp_{is_ckp}_offload_{args.with_activation_offload}_profile.pkl" ) - logging.info("*" * 20) return loss_res @@ -503,7 +218,7 @@ def visit_and_register_hooks(module): # You could set the logger level to INFO to view more runtime # information. - logger.setLevel(logging.WARNING) + logger.setLevel(logging.INFO) if not torch.distributed.is_initialized(): torch.distributed.init_process_group( backend="gloo" if args.use_fake_dist else "nccl" @@ -511,171 +226,17 @@ def visit_and_register_hooks(module): world_size = torch.distributed.get_world_size() - MODEL_NAME = args.model_name if res_check: - MODEL_NAME = "Bert" - if MODEL_NAME == "Bert": - # 0.11B - HIDDEN_DIM = 768 - SEQ_LEN = 512 - NUM_LAYER = 6 - NUM_HEAD = 12 - elif MODEL_NAME == "Bertlarge": - # 0.35B - HIDDEN_DIM = 1024 - SEQ_LEN = 512 - NUM_LAYER = 24 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT2small": - # 0.7B - HIDDEN_DIM = 1536 - SEQ_LEN = 128 - NUM_LAYER = 24 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT2_1B": - # 0.9B - HIDDEN_DIM = 2048 - SEQ_LEN = 1024 - NUM_LAYER = 20 - NUM_HEAD = 16 - elif MODEL_NAME == "megatron_1.3B": - HIDDEN_DIM = 2048 - SEQ_LEN = 1024 - NUM_LAYER = 24 - NUM_HEAD = 32 - elif MODEL_NAME == "GPT2_2B": - # zero-offload - HIDDEN_DIM = 2048 - SEQ_LEN = 1024 - NUM_LAYER = 40 - NUM_HEAD = 16 - elif MODEL_NAME == "megatron_3.9B": - # Table 4 in Megatron Paper - HIDDEN_DIM = 2560 - SEQ_LEN = 1024 - NUM_LAYER = 24 - NUM_HEAD = 40 - elif MODEL_NAME == "GPT2_4B": - HIDDEN_DIM = 2304 # 2048 - SEQ_LEN = 1024 - NUM_LAYER = 64 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT3_6B": - # 6.7B model - HIDDEN_DIM = 3072 - SEQ_LEN = 1024 - NUM_LAYER = 53 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT3_8B": - # 6.7B model - HIDDEN_DIM = 3072 - SEQ_LEN = 1024 - NUM_LAYER = 72 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT3_10B": - HIDDEN_DIM = 4096 - SEQ_LEN = 1024 - NUM_LAYER = 50 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT3_11B": - HIDDEN_DIM = 4096 - SEQ_LEN = 1024 - NUM_LAYER = 55 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT3_12B": - HIDDEN_DIM = 4096 - SEQ_LEN = 1024 - NUM_LAYER = 60 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT3_13B": - HIDDEN_DIM = 4096 - SEQ_LEN = 1024 - NUM_LAYER = 65 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT3_15B": - HIDDEN_DIM = 4096 - SEQ_LEN = 1024 - NUM_LAYER = 78 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT3_18B": - HIDDEN_DIM = 4096 - SEQ_LEN = 1024 - NUM_LAYER = 90 - NUM_HEAD = 16 - # The following configs comes from paper - # Efficient Large-Scale Language Model Training on GPU Clusters - # NV model is wider in hidden-size - elif MODEL_NAME == "GPT_NV_18B": - HIDDEN_DIM = 6144 - SEQ_LEN = 1024 - NUM_LAYER = 40 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT_NV_39B": - HIDDEN_DIM = 8192 - SEQ_LEN = 1024 - NUM_LAYER = 48 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT_NV_76B": - HIDDEN_DIM = 10240 - SEQ_LEN = 1024 - NUM_LAYER = 60 - NUM_HEAD = 16 - # The following configs comes from Deep-Offload - # http://pasalabs.org/papers/2021/ATC21_zero-offload.pdf - elif MODEL_NAME == "GPT_DS_20B": - HIDDEN_DIM = 8192 - SEQ_LEN = 1024 - NUM_LAYER = 25 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT_DS_40B": - HIDDEN_DIM = 8192 - SEQ_LEN = 1024 - NUM_LAYER = 50 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT_DS_50B": - HIDDEN_DIM = 8192 - SEQ_LEN = 1024 - NUM_LAYER = 62 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT_DS_60B": - HIDDEN_DIM = 8192 - SEQ_LEN = 1024 - NUM_LAYER = 75 - NUM_HEAD = 16 - elif MODEL_NAME == "GPT_DS_70B": - HIDDEN_DIM = 9216 - SEQ_LEN = 1024 - NUM_LAYER = 69 - NUM_HEAD = 16 - else: - raise RuntimeError(f"The model name {MODEL_NAME} is not valid!") - if res_check: - BATCH_SIZE = 2 - else: - BATCH_SIZE = args.batch_size - - assert HIDDEN_DIM % NUM_HEAD == 0 - logging.info(f"Benchmarking {MODEL_NAME}") - - print_model_config( - hidden_dim=HIDDEN_DIM, - sequence_len=SEQ_LEN, - num_layer=NUM_LAYER, - num_head=NUM_HEAD, - ) + args.model_name = "Bert" + args.batch_size = 2 if not res_check: torch.manual_seed(0) - loss_list = test_bert_model_helper( + loss_list = test_transformer_model_helper( args=args, is_ckp=use_ckp, is_fp16=use_fp16, dist_plan=dist_plan, - batch_size=BATCH_SIZE, - hidden_dim=HIDDEN_DIM, - sequence_length=SEQ_LEN, - num_layer=NUM_LAYER, - num_head=NUM_HEAD, num_steps=5, ) print("*" * 20 + " LOSS " + "*" * 20) @@ -689,16 +250,11 @@ def visit_and_register_hooks(module): NUM_STEPS = 5 torch.manual_seed(0) - torch_res_list = test_bert_model_helper( + torch_res_list = test_transformer_model_helper( args=args, is_ckp=use_ckp, is_fp16=False, dist_plan="torch", - hidden_dim=HIDDEN_DIM, - batch_size=BATCH_SIZE, - sequence_length=SEQ_LEN, - num_layer=NUM_LAYER, - num_head=NUM_HEAD, num_steps=NUM_STEPS, ) @@ -706,16 +262,11 @@ def visit_and_register_hooks(module): logging.info("-" * 50) torch.manual_seed(0) - autocast_res_list = test_bert_model_helper( + autocast_res_list = test_transformer_model_helper( args=args, is_ckp=use_ckp, is_fp16=True, dist_plan="torch", - hidden_dim=HIDDEN_DIM, - batch_size=BATCH_SIZE, - sequence_length=SEQ_LEN, - num_layer=NUM_LAYER, - num_head=NUM_HEAD, num_steps=NUM_STEPS, ) @@ -723,17 +274,11 @@ def visit_and_register_hooks(module): logging.info("-" * 50) torch.manual_seed(0) - ps_res_list = test_bert_model_helper( + ps_res_list = test_transformer_model_helper( args=args, is_ckp=use_ckp, is_fp16=use_fp16, dist_plan="patrickstar", - hidden_dim=HIDDEN_DIM, - batch_size=BATCH_SIZE, - sequence_length=SEQ_LEN, - num_layer=NUM_LAYER, - num_head=NUM_HEAD, - num_steps=NUM_STEPS, ) print("-" * 20 + " LOSS " + "-" * 20) diff --git a/examples/ps_config.py b/examples/ps_config.py new file mode 100644 index 000000000..0e5aeb366 --- /dev/null +++ b/examples/ps_config.py @@ -0,0 +1,78 @@ +# BSD 3-Clause License +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the psutil authors nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +def get_patrickstar_config( + args, lr=0.001, betas=(0.9, 0.999), eps=1e-6, weight_decay=0 +): + config = { + # The same format as optimizer config of DeepSpeed + # https://www.deepspeed.ai/docs/config-json/#optimizer-parameters + "optimizer": { + "type": "Adam", + "params": { + "lr": lr, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "use_hybrid_adam": args.use_hybrid_adam, + }, + }, + "fp16": { + "enabled": True, + # Set "loss_scale" to 0 to use DynamicLossScaler. + "loss_scale": 0, + "initial_scale_power": args.init_loss_scale_power, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + "default_chunk_size": args.default_chunk_size, + "release_after_init": args.release_after_init, + "use_fake_dist": args.use_fake_dist, + "use_cpu_embedding": args.use_cpu_embedding, + "client": { + "mem_tracer": { + "use_async_mem_monitor": args.with_async_mem_monitor, + "warmup_gpu_chunk_mem_ratio": 0.1, + "overall_gpu_mem_ratio": 0.9, + "overall_cpu_mem_ratio": 0.9, + "margin_use_ratio": 0.8, + "use_fake_dist": False, + "with_static_partition": args.with_static_partition, + }, + "opts": { + "with_mem_saving_comm": args.with_mem_saving_comm, + "with_mem_cache": args.with_mem_cache, + "with_async_move": args.with_async_move, + }, + }, + } + + return config diff --git a/examples/run_bert.sh b/examples/run_transformers.sh similarity index 83% rename from examples/run_bert.sh rename to examples/run_transformers.sh index 53da4329a..1c8f15624 100644 --- a/examples/run_bert.sh +++ b/examples/run_transformers.sh @@ -35,8 +35,8 @@ export CACHE=${CACHE:-1} export ASYNC_MOVE=${ASYNC_MOVE:-0} # linear tiling comm export TILING=${TILING:-0} - export LOCAL_WORLD_SIZE=${LOCAL_WORLD_SIZE:-1} +export CS_SEARCH=${CS_SEARCH:-0} if [[ ${TILING} == 1 ]]; then TILING_FLAG="--with_tiling_linear" @@ -102,14 +102,6 @@ else export RELEASE_AFTER_INIT_FLAG="" fi -export LIGHTSEQ=0 -if [[ ${LIGHTSEQ} == 1 ]]; then -export LIGHTSEQ_FLAG="--with_lightseq" -else -export LIGHTSEQ_FLAG="" -fi - - if [[ ${CKP} == 1 ]]; then export CKP_FLAG="--use_ckp" else @@ -123,7 +115,7 @@ LOG_DIR="./logs_${MODEL_NAME}" mkdir -p ${LOG_DIR} GIT_VER=`git rev-parse --short=5 HEAD` -LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_lightseq_${LIGHTSEQ}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}" +LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}" is_run_flag=`python ./benchmark/is_run_this_file.py --path "${LOG_DIR}" --file "${LOG_FILE}"` echo is_run_flag $is_run_flag @@ -149,8 +141,8 @@ fi wc=`cat /proc/cpuinfo | grep "processor"| wc -l` let TNUM=wc/${GPU_NUM} echo "CPU core number " $wc "THREAD NUM " ${TNUM} -env OMP_NUM_THREADS=${TNUM} timeout -s SIGKILL 30m python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} \ - pretrain_bert_demo.py \ + +cmd_opts=" --use_fp16 \ ${RES_CHECK_FLAG} \ ${NO_RETRY_FLAG} \ @@ -162,7 +154,6 @@ env OMP_NUM_THREADS=${TNUM} timeout -s SIGKILL 30m python -m torch.distributed.l ${CPU_EBD_FLAG} \ ${HYBRID_ADAM_FLAG} \ ${RELEASE_AFTER_INIT_FLAG} \ - --default_chunk_size=${CHUNK_SIZE} \ ${LIGHTSEQ_FLAG} \ ${ACT_OFFLOAD_FLAG} \ ${SP_FLAG} \ @@ -172,4 +163,28 @@ env OMP_NUM_THREADS=${TNUM} timeout -s SIGKILL 30m python -m torch.distributed.l ${CACHE_FLAG} \ ${ASYNC_MOVE_FLAG} \ ${TILING_FLAG} \ +" + +if [[ ${CS_SEARCH} == 1 ]]; then +mkdir -p ./search_res +SLOG_FILE="./search_res/slog_file.${MODEL_NAME}_bs_${BS}_cpueb_${CPU_EBD}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}" +rm -rf ${SLOG_FILE} + +for((i=312;i>=64;i-=32)); +do +let CUR_CHUNK_SIZE=${i}*1024*1024 +echo "searching CHUNK_SIZE ${i} M elem" + +python -m torch.distributed.launch --nproc_per_node=1 \ + eval_chunk_size.py \ + --default_chunk_size=${CUR_CHUNK_SIZE} \ + --slog_file=${SLOG_FILE} \ + ${cmd_opts} +done +else +env OMP_NUM_THREADS=${TNUM} timeout -s SIGKILL 30m python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} \ + pretrain_bert_demo.py \ + --default_chunk_size=${CHUNK_SIZE} \ + ${cmd_opts} \ 2>&1 | tee ${LOG_DIR}/${LOG_FILE} +fi diff --git a/patrickstar/core/client.py b/patrickstar/core/client.py index 3015be784..0121fe269 100644 --- a/patrickstar/core/client.py +++ b/patrickstar/core/client.py @@ -31,7 +31,7 @@ import torch import patrickstar.utils.global_timer as global_timer -from patrickstar.utils import logger, get_world_size, get_rank +from patrickstar.utils import logger, get_world_size, get_rank, log_dist from .chunk_list import ChunkList, ChunkType from .chunk_tensor_index import ChunkTensorIndex from .const import AccessType, ChunkState, TensorState, TrainingStage @@ -91,7 +91,7 @@ def __init__(self, rank: int, default_chunk_size: int, config=None): self.opt_config["with_async_move"], ) if self.opt_config["with_mem_cache"]: - print("[CONFIG] USING MEM CACHE") + logger.debug("[CONFIG] USING MEM CACHE") self._time_profile = True if torch.distributed.is_initialized(): @@ -229,7 +229,7 @@ def append_dummy_chunk(self, data_type: torch.dtype, chunk_type: ChunkType): AccessType.DATA, ) - logger.info( + logger.debug( f"Append a dummy chunk to the Chunk List {chunk_type} " f"comm info {comm_info}" ) @@ -272,7 +272,7 @@ def append_tensor( chunk_id, param_list, access_type ): raise RuntimeError( - f"Can not append a tensor to chunk_tensor_index." + f"Can not append a tensor to chunk_tensor_index. " f"Overall size of param list is larger than the default chunk size {self.default_chunk_size}." ) return @@ -464,7 +464,7 @@ def _fetch_remote_chunks( "CLIENT_fetch_remote_chunks_allgather" ) - logger.info(f"rank {rank} allgather {chunk_id_list}") + logger.debug(f"rank {rank} allgather {chunk_id_list}") torch.distributed.all_gather( allgather_payload_buff, self.chunk_list[local_chunk_id].payload, @@ -910,21 +910,58 @@ def release_grad( def reset(self): raise NotImplementedError + def get_overall_chunk_size(self): + """ + return the overall size of all chunks and + the overall chunk utilization excluding fragments. + Excepting the dummy chunk if using MSC. + """ + overall_size = 0 + overall_chunk_num = 0 + overall_utilization_ratio = 0.0 + for ( + type, + type_chunk_list, + ) in self.chunk_tensor_index.chunk_type_to_chunk_id_list_map.items(): + + logger.debug(f"Chunk list {type}") + for chunk_id in type_chunk_list: + chunk = self.chunk_list[chunk_id] + if self.opt_config["with_mem_saving_comm"] and chunk.is_dummy(): + continue + comm_info = self.chunk_tensor_index.chunk_id_to_comm_info_map[chunk_id] + assert comm_info is not None + last_used_pos = 0 + for info in self.chunk_tensor_index.generate_tensor_info_in_order( + chunk_id + ): + last_used_pos = max(last_used_pos, info.start_offset + info.numel) + overall_utilization_ratio += last_used_pos / chunk.capacity + overall_size += chunk.get_chunk_space() + overall_chunk_num += 1 + overall_utilization_ratio /= overall_chunk_num + return overall_size, overall_utilization_ratio + def display_chunk_info(self): - logger.info("Print chunk list info.") + logger.debug("Print chunk list info.") overall_size = 0 + overall_chunk_num = 0 + overall_utilization_ratio = 0.0 + max_utilization_ratio = 0 for ( type, type_chunk_list, ) in self.chunk_tensor_index.chunk_type_to_chunk_id_list_map.items(): - logger.info(f"Chunk list {type}") + logger.debug(f"Chunk list {type}") for chunk_id in type_chunk_list: chunk = self.chunk_list[chunk_id] + if self.opt_config["with_mem_saving_comm"] and chunk.is_dummy(): + continue comm_info = self.chunk_tensor_index.chunk_id_to_comm_info_map[chunk_id] assert comm_info is not None - logger.info( + logger.debug( f"Chunk id {chunk.chunk_id}, state {chunk.get_state()}, " f"comm info {comm_info}, " f"capacity {chunk.capacity / 1024 / 1024} M elems, " @@ -935,16 +972,24 @@ def display_chunk_info(self): chunk_id ): assert info.chunk_id == chunk_id, f"{info.chunk_id} vs {chunk_id}" - logger.info( + logger.debug( f"** tensor: chunk_id {chunk_id}, start {info.start_offset}, " f"end {info.start_offset + info.numel}, size {info.numel}, " f"tensor_id {info.tensor_id}, state {info.state()}, name {info.tensor_name}" ) last_used_pos = max(last_used_pos, info.start_offset + info.numel) - logger.info( + logger.debug( f"chunk used {last_used_pos/1024/1024} M elem, " f"{last_used_pos/chunk.capacity * 100} %" ) + cur_util = last_used_pos / chunk.capacity + max_utilization_ratio = max(cur_util, max_utilization_ratio) + overall_utilization_ratio += cur_util overall_size += chunk.get_chunk_space() + overall_chunk_num += 1 - logger.info(f"OVERALL CHUNK SIZE {overall_size / 1024 / 1024 / 1024} GB") + log_dist(f"OVERALL CHUNK SIZE {overall_size / 1024 / 1024 / 1024} GB") + log_dist( + f"OVERALL UTILIZATION {overall_utilization_ratio / overall_chunk_num * 100} %" + ) + log_dist(f"MAX UTILIZATION {max_utilization_ratio * 100} %") diff --git a/patrickstar/core/memtracer/memtracer.py b/patrickstar/core/memtracer/memtracer.py index 2d67bacd5..cee477362 100644 --- a/patrickstar/core/memtracer/memtracer.py +++ b/patrickstar/core/memtracer/memtracer.py @@ -34,6 +34,7 @@ from patrickstar.core.const import TrainingStage from patrickstar.profiler import profiler from patrickstar.utils import ( + log_dist, get_memory_info, get_sys_memory_used, get_world_size, @@ -125,10 +126,6 @@ def __init__(self, local_rank: int = 0, config=None): if self.use_async_mem_monitor: self.async_mem_monitor = AsyncMemoryMonitor() - print( - f"[Mem Tracer] Using Asyn Mem Monitor Flag : {self.use_async_mem_monitor}" - ) - mem_info = get_memory_info() local_world_size = get_local_world_size() if self.use_fake_dist: @@ -150,7 +147,7 @@ def __init__(self, local_rank: int = 0, config=None): mem_info.total * self._overall_cpu_mem_ratio / local_world_size ) - logger.info( + log_dist( f"Init Manager over all gpu mem {self._overall_gpu_mem / 1e6} MB, " f"cpu mem {self._overall_cpu_mem / 1e6} MB" ) @@ -175,14 +172,14 @@ def close_tracer(self): """ if self.use_async_mem_monitor: self.async_mem_monitor.finish() - print("**** Memory Tracer is closed! ****") + log_dist("**** Memory Tracer is closed! ****") def start_train(self, param_fp16_chunk_size, chunk_size): self._param_fp16_chunk_size = param_fp16_chunk_size self._default_chunk_size = chunk_size if self.use_async_mem_monitor: self.async_mem_monitor.start() - print("**** Memory Tracer is stared! ****") + log_dist("**** Memory Tracer is stared! ****") def update_margin_mem(self): r"""Update the number of GPU free chunks for optimizer.""" @@ -193,6 +190,15 @@ def update_margin_mem(self): max_gpu_sys_used = 0 else: max_gpu_sys_used = max(self.gpu_sys_used_list) + + if len(self.cpu_sys_used_list) == 0: + logger.warning( + "No gpu info collected. Maybe there are no chunk based tensors." + ) + max_cpu_sys_used = 0 + else: + max_cpu_sys_used = max(self.cpu_sys_used_list) + margin_mem_size = ( self._overall_gpu_mem - max_gpu_sys_used - self._param_fp16_chunk_size ) @@ -201,14 +207,16 @@ def update_margin_mem(self): (margin_mem_size) / (self._default_chunk_size * 12) * self._margin_use_ratio ) - logger.info("--------------- GPU INFO AFTER BWD ----------------") - logger.info(f"Max GPU System Mem (non-chunk) Used {max_gpu_sys_used / 1e6} MB") - logger.info(f"Param FP16 Chunk Size {self._param_fp16_chunk_size / 1e6} MB") - logger.info( + log_dist("--------------- GPU INFO AFTER BWD ----------------") + log_dist(f"Max GPU System Mem (non-chunk) Used {max_gpu_sys_used / 1e6} MB") + log_dist(f"Max CPU System Mem (non-chunk) Used {max_cpu_sys_used / 1e6} MB") + log_dist(f"Param FP16 Chunk Size {self._param_fp16_chunk_size / 1e6} MB") + log_dist( f"Margin Mem Size {margin_mem_size / 1e6} MB, " f"available chunk num for Optimizer States {self._margin_chunk_num_for_gpu_adam}" ) - logger.info(f"OVERALL GPU MEM {self._overall_gpu_mem}") + log_dist("--------------- GPU INFO AFTER BWD ----------------") + logger.debug(f"OVERALL GPU MEM {self._overall_gpu_mem/1024/1024} MB") def reset_memory_stats(self): """ @@ -228,7 +236,7 @@ def reset_memory_stats(self): self.gpu_used_list = [] self.gpu_chunk_used_list = [] self.gpu_sys_used_list = [] - logger.info("Reset Memory Statistics") + log_dist("Reset Memory Statistics") def get_margin_chunk_num_for_gpu_adam(self): return self._margin_chunk_num_for_gpu_adam diff --git a/patrickstar/core/preprocess.py b/patrickstar/core/preprocess.py index 253ae8541..06885ed18 100644 --- a/patrickstar/core/preprocess.py +++ b/patrickstar/core/preprocess.py @@ -35,7 +35,7 @@ from patrickstar.core import register_param, is_param_registered, ParamType from patrickstar.manager import _runtime_config from patrickstar.ops import Embedding -from patrickstar.utils import logger, print_rank, get_rank, get_world_size +from patrickstar.utils import logger, log_dist, print_rank, get_rank, get_world_size from patrickstar.utils import see_memory_usage _orig_torch_empty = torch.empty @@ -219,6 +219,7 @@ def __init__( release_after_init=False, use_cpu_embedding=False, dtype=None, + not_init=False, ): super().__init__(config=None, dtype=dtype) self.rank = get_rank() @@ -231,6 +232,7 @@ def __init__( self.use_cpu_embedding = use_cpu_embedding self.submodule_id = -1 + self.not_init = not_init def _pre_context_exec(self): Embedding.use_cpu = self.use_cpu_embedding @@ -249,7 +251,7 @@ def _post_context_exec(self): number of processes. 3. Add a dummy param at the start of CPU Embedding for huggingface. """ - logger.info("Post Model Init Context") + log_dist("Post Model Init Context") def _origin_new(cls, *arg, **kwargs): return object.__new__(cls) @@ -297,24 +299,25 @@ def _origin_new(cls, *arg, **kwargs): param_fp32_chunk_id ), ): - if is_param_registered(param_fp32) and is_param_registered( - param_fp16 - ): - ps_data_fp16 = self.client.access_data( - param_fp16, torch.device("cpu:0") - ) - - ps_data_fp32 = self.client.access_data( - param_fp32, torch.device("cpu:0") - ) - - # Here the dtype of param_fp16 is actually fp32. - ps_data_fp16.copy_(param_fp16.data) - ps_data_fp32.copy_(param_fp16.data) - - self.client.release_data(param_fp16) - self.client.release_data(param_fp32) - param_fp16 = param_fp16.to(torch.half) + if not self.not_init: + if is_param_registered(param_fp32) and is_param_registered( + param_fp16 + ): + ps_data_fp16 = self.client.access_data( + param_fp16, torch.device("cpu:0") + ) + + ps_data_fp32 = self.client.access_data( + param_fp32, torch.device("cpu:0") + ) + + # Here the dtype of param_fp16 is actually fp32. + ps_data_fp16.copy_(param_fp16.data) + ps_data_fp32.copy_(param_fp16.data) + + self.client.release_data(param_fp16) + self.client.release_data(param_fp32) + param_fp16 = param_fp16.to(torch.half) else: for param_fp16 in self.client.chunk_tensor_index.params_generator( param_fp16_chunk_id @@ -330,7 +333,7 @@ def _origin_new(cls, *arg, **kwargs): chunk_num += 1 world_size = get_world_size() - logger.info(f"param fp16 chunk num {chunk_num}") + log_dist(f"Param fp16 chunk num {chunk_num}") while chunk_num % world_size != 0: self.client.append_dummy_chunk(torch.half, ChunkType.PARAM_FP16) chunk_num += 1 @@ -360,8 +363,6 @@ def _post_init_method(self, module): self.client.torch_param_allreduce_list.append(param) return - print_rank(f"Converting Params in {module.__class__.__name__}", force=False) - if not _runtime_config.use_chunk: for name, param in module.named_parameters(recurse=False): name = f"{module.__class__.__name__}.{name}_{self.param_idx}" diff --git a/patrickstar/ops/chunk_io_buff.py b/patrickstar/ops/chunk_io_buff.py index de8e37bfb..9218f0333 100644 --- a/patrickstar/ops/chunk_io_buff.py +++ b/patrickstar/ops/chunk_io_buff.py @@ -135,7 +135,7 @@ def reset(self): if self.cached_src_chunk_id is None: return global_rank = get_rank() - logger.info( + logger.debug( f"global_rank {global_rank} finally, write chunk {self.cached_target_chunk_id}" ) # It's possible that the chunk is empty (no payload), e.g. the process only possesses @@ -197,13 +197,13 @@ def __init__( gpu_device, chunk_size, torch.half, False ) else: - logger.info( + logger.debug( f"Allocate fp32 Chunk Buffer of size {chunk_size / 1e6} MB on CPU." ) self.gpu_payload = torch.empty( chunk_size, dtype=torch.half, device=gpu_device ) - logger.info( + logger.debug( f"Allocate fp32 Chunk Buffer of size {chunk_size / 1e6} MB on {gpu_device}." ) self.cached_chunk_id = None diff --git a/patrickstar/ops/fp16_cpu_adam.py b/patrickstar/ops/fp16_cpu_adam.py index 11610b374..111ce44f5 100644 --- a/patrickstar/ops/fp16_cpu_adam.py +++ b/patrickstar/ops/fp16_cpu_adam.py @@ -329,7 +329,7 @@ def fp16_chunk_adam_ops( Copy the chunk into a tmp buffer to speed up the memcpy between devices. """ local_rank = client.local_rank - logger.info( + logger.debug( f"local_rank {local_rank} margin_chunk_num_for_gpu_adam {margin_chunk_num_for_gpu_adam}, " f"param cnt {len(fp32_param_list)}" ) diff --git a/patrickstar/runtime/__init__.py b/patrickstar/runtime/__init__.py index 469fbdb29..030ed31bb 100644 --- a/patrickstar/runtime/__init__.py +++ b/patrickstar/runtime/__init__.py @@ -30,8 +30,9 @@ import torch from patrickstar.core import PSPreProcessCtx, PatrickStarClient from patrickstar.core.memtracer import RuntimeMemTracer -from patrickstar.utils import logger +from patrickstar.utils import logger, log_dist from .engine import PatrickStarEngine +import time DEFAULT_CHUNK_SIZE = 32 * 1024 * 1024 @@ -73,6 +74,8 @@ def initialize_engine(model_func, local_rank, config=None, client=None): config=config.get("client", None), ) + start_time = time.time() + log_dist("begin initialize the model parameters...") with PSPreProcessCtx( client=client, dtype=torch.float, @@ -80,6 +83,10 @@ def initialize_engine(model_func, local_rank, config=None, client=None): use_cpu_embedding=use_cpu_embedding, ): model = model_func() + end_time = time.time() + log_dist( + f"finished initialized the model parameters... {end_time - start_time} s" + ) engine = PatrickStarEngine(model=model, client=client, config=config) client.start_mem_tracer() diff --git a/patrickstar/runtime/engine.py b/patrickstar/runtime/engine.py index 7d97689a9..d235daafc 100644 --- a/patrickstar/runtime/engine.py +++ b/patrickstar/runtime/engine.py @@ -32,7 +32,7 @@ from patrickstar.core import ChunkState, TensorState, TrainingStage, ParamType from patrickstar.fp16 import LossScaler, DynamicLossScaler from patrickstar.ops import FP16Adam -from patrickstar.utils import logger, global_timer +from patrickstar.utils import log_dist, global_timer from .checkpoint import state_dict, load_state_dict from patrickstar.profiler import profiler @@ -86,7 +86,7 @@ def __init__(self, model, client, config): ), "Must have `loss_scale` field set." loss_scale = loss_scale_config["loss_scale"] if loss_scale == 0: - logger.info("Use DynamicLossScaler") + log_dist("Use DynamicLossScaler") self.loss_scaler = DynamicLossScaler( init_scale=( 2 ** loss_scale_config.get("initial_scale_power", 16) @@ -129,7 +129,7 @@ def __init__(self, model, client, config): self.iteration_cnt_ = 0 # TODO(jiaruifang) pass in via config. self.warmup_times = 1 - logger.info("PatrickStarEngine initialized.") + log_dist("PatrickStarEngine initialized.") def _move_torch_parts_to_gpu(self, model): # TODO(zilinzhu) Currently we move all buffers to GPU as the buffer size is diff --git a/patrickstar/utils/global_timer.py b/patrickstar/utils/global_timer.py index 240484840..74fd47ef4 100644 --- a/patrickstar/utils/global_timer.py +++ b/patrickstar/utils/global_timer.py @@ -29,7 +29,7 @@ import time -from .logging import logger +# from .logging import logger from .singleton_meta import SingletonMeta @@ -72,7 +72,7 @@ def reset(self): def print(self): if not self.start_flag: return - logger.info("------------- PROFILE RESULTS ----------------") + print("------------- PROFILE RESULTS ----------------") dot_length = 20 for k in self.elapse_stat: dot_length = max(dot_length, len(k) + 2) diff --git a/patrickstar/utils/logging.py b/patrickstar/utils/logging.py index bdc8a23f2..1dea2e72c 100644 --- a/patrickstar/utils/logging.py +++ b/patrickstar/utils/logging.py @@ -28,6 +28,7 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import logging import sys +from rich.logging import RichHandler import torch.distributed as dist @@ -57,14 +58,14 @@ def create_logger(name=None, level=logging.WARNING): logger_.propagate = False ch = logging.StreamHandler(stream=sys.stdout) ch.setFormatter(formatter) - logger_.addHandler(ch) + logger_.addHandler(RichHandler()) return logger_ logger = LoggerFactory.create_logger(name="PatrickStar", level=logging.WARNING) -def log_dist(message, ranks=None, level=logging.INFO): +def log_dist(message, ranks=[0], level=logging.INFO): """Log message when one of following condition meets + not dist.is_initialized() + dist.get_rank() in ranks if ranks is not None or ranks = [-1] diff --git a/requirements.txt b/requirements.txt index 377edef87..169c74355 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ torch>=1.5.0 pytest psutil ninja +rich diff --git a/setup.py b/setup.py index b999b98a1..0ed063765 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def fetch_requirements(path): setup( name="patrickstar", - version="0.4.4", + version="0.4.5", description="PatrickStart library", long_description="PatrickStar: Parallel Training of Large Language Models via a Chunk-based Parameter Server", long_description_content_type="text/markdown",