Commit

Merge branch 'develop' of github.com:Tencent/PatrickStar into develop
feifeibear committed Dec 13, 2021
2 parents e9695ca + a1eac30 commit ee5c996
Showing 14 changed files with 82 additions and 70 deletions.
18 changes: 7 additions & 11 deletions examples/eval_chunk_size.py
@@ -31,26 +31,21 @@
import logging
import torch

# from patrickstar.utils.logging import logger
from patrickstar.utils.logging import logger, log_dist
from model_builder import build_transformer_model
from ps_config import get_patrickstar_config
from parse_args import parse_args
from patrickstar.core import PatrickStarClient
from patrickstar.core import PSPreProcessCtx

import time
from patrickstar.utils.distributed import get_rank

from rich.logging import RichHandler

logger = logging.getLogger(__name__)
logger.addHandler(RichHandler())

MB_NUM = 1024 * 1024
GB_NUM = 1024 * MB_NUM

HARDWARE_SETTING_JSON = {
"per_cpu_mem": 16 * GB_NUM,
"per_gpu_mem": 8 * GB_NUM,
"per_cpu_mem": 240 * GB_NUM,
"per_gpu_mem": 32 * GB_NUM,
"global_gpu_num": 1,
"gloabl_cpu_num": 1,
"local_gpu_num": 1,
@@ -127,7 +122,7 @@ def get_param_used_chunk_size(args, config, model_func):
default_chunk_size=args.default_chunk_size,
config=config.get("client", None),
)

start_time = time.time()
try:
with PSPreProcessCtx(
client=client,
@@ -140,7 +135,8 @@ def get_param_used_chunk_size(args, config, model_func):
except Exception:
logger.error("PSPreProcessCtx failed")
return -1, -1

end_time = time.time()
log_dist(f"PSPreProcessCtx Model Constructing elapse {end_time - start_time}")
del model

overall_chunk_size, util = client.get_overall_chunk_size()
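The timing added around PSPreProcessCtx above follows the usual bracketing pattern: record time.time() before and after model construction and report the delta through a rank-aware logger. Below is a minimal, self-contained sketch of that pattern; it assumes log_dist behaves like a rank-0-only print (the real helper lives in patrickstar.utils.logging and may differ), and build_model is a hypothetical placeholder for the construction done inside PSPreProcessCtx.

import time

import torch.distributed as dist


def log_dist(message, ranks=(0,)):
    # Hedged stand-in for patrickstar.utils.logging.log_dist: emit the message
    # only on the listed ranks so multi-GPU runs stay readable.
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank in ranks:
        print(f"[rank {rank}] {message}")


def build_model():
    # Hypothetical placeholder for the work done inside PSPreProcessCtx.
    return object()


start_time = time.time()
model = build_model()
end_time = time.time()
log_dist(f"PSPreProcessCtx Model Constructing elapse {end_time - start_time}")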
4 changes: 2 additions & 2 deletions examples/optimizations/ps_tile_modeling_bert.py
@@ -584,7 +584,7 @@ def __init__(self, config, add_pooling_layer=True):

self.pooler = BertPooler(config) if add_pooling_layer else None

self.init_weights()
# self.init_weights()

def get_input_embeddings(self):
return self.embeddings.word_embeddings
@@ -771,7 +771,7 @@ def __init__(self, config):
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()
# self.init_weights()

def forward(
self,
17 changes: 6 additions & 11 deletions examples/pretrain_bert_demo.py
@@ -39,7 +39,7 @@
from patrickstar.profiler import profiler
from patrickstar.runtime import initialize_engine
from patrickstar.utils import see_memory_usage
from patrickstar.utils.logging import logger
from patrickstar.utils.logging import log_dist, logger
from patrickstar.utils.model_size_calculator import get_ps_model_size
from model_builder import build_transformer_model
from parse_args import parse_args
@@ -53,11 +53,6 @@ def test_transformer_model_helper(
dist_plan: str = "torch",
num_steps=5,
):
logger.info(
f'test a bert {"fp16" if is_fp16 else "fp32"} model '
f'{"with checkpoint" if is_ckp else ""}'
)

# Use single card to simulate multi-card. Used when you are poor and
# no more GPU available.
if args.use_fake_dist:
@@ -117,9 +112,9 @@ def test_transformer_model_helper(
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

model_numel, model_num_param = get_ps_model_size(model)
logger.info(f"Model size {model_numel / 1e9} B, total params: {model_num_param}")
log_dist(f"Model size {model_numel / 1e9} B, total params: {model_num_param}")
total_macs = model_numel * args.batch_size * sequence_length * 2 * 4
logger.info(f"Total MACs: {total_macs/1024/1024/1024/1024} TFlops")
log_dist(f"Total MACs: {total_macs/1024/1024/1024/1024} TFlops")

see_memory_usage(
f"After model init. using {dist_plan}, gradient checkpoint: {is_ckp}, fp16 {is_fp16}",
@@ -145,7 +140,7 @@ def test_transformer_model_helper(
break
# You may need to empty_cache for really large models.
torch.cuda.empty_cache()
logger.info(f"Start Step {n} with {dist_plan}...")
log_dist(f"Start Step {n} with {dist_plan}...")

step_start_time = time.time()
# Only collect running time of the last iteration.
@@ -201,7 +196,7 @@ def test_transformer_model_helper(
f"Step {n} elaspe {step_elapse} s, {total_macs / 1e12 / step_elapse} Tflops"
)

logger.info(f"End Step {n} with {dist_plan}.\n")
log_dist(f"End Step {n} with {dist_plan}.\n")

if args.with_mem_profiler:
profiler.end()
@@ -223,7 +218,7 @@ def test_transformer_model_helper(

# You could set the logger level to INFO to view more runtime
# information.
logger.setLevel(logging.WARNING)
logger.setLevel(logging.INFO)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="gloo" if args.use_fake_dist else "nccl"
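The throughput reported in pretrain_bert_demo.py comes from an analytic estimate rather than a profiler: total_macs = model_numel * batch_size * sequence_length * 2 * 4. The factor 2 reads as two operations per parameter per token in a forward pass, and the factor 4 presumably covers forward, backward (roughly twice forward), and the extra forward introduced by activation checkpointing; that reading is an assumption, since the script does not document the constants. A small worked example with made-up numbers:

# All values below are illustrative, not taken from the commit.
model_numel = 1.5e9          # 1.5 B parameters
batch_size = 8
sequence_length = 512
step_elapse = 12.3           # seconds per training step (made up)

# 2 ops per parameter per token, x4 assumed for fwd + bwd + recompute.
total_macs = model_numel * batch_size * sequence_length * 2 * 4

print(f"Total MACs per step: {total_macs / 1e12:.2f} T")
print(f"Throughput: {total_macs / 1e12 / step_elapse:.2f} T ops/s")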
6 changes: 3 additions & 3 deletions examples/run_transformers.sh
@@ -115,7 +115,7 @@ LOG_DIR="./logs_${MODEL_NAME}"
mkdir -p ${LOG_DIR}

GIT_VER=`git rev-parse --short=5 HEAD`
LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_lightseq_${LIGHTSEQ}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}"
LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}"

is_run_flag=`python ./benchmark/is_run_this_file.py --path "${LOG_DIR}" --file "${LOG_FILE}"`
echo is_run_flag $is_run_flag
@@ -167,7 +167,7 @@ cmd_opts="

if [[ ${CS_SEARCH} == 1 ]]; then
mkdir -p ./search_res
SLOG_FILE="./search_res/slog_file.${MODEL_NAME}_bs_${BS}_cpueb_${CPU_EBD}_lightseq_${LIGHTSEQ}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}"
SLOG_FILE="./search_res/slog_file.${MODEL_NAME}_bs_${BS}_cpueb_${CPU_EBD}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}"
rm -rf ${SLOG_FILE}

for((i=312;i>=64;i-=32));
@@ -185,6 +185,6 @@ else
env OMP_NUM_THREADS=${TNUM} timeout -s SIGKILL 30m python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} \
pretrain_bert_demo.py \
--default_chunk_size=${CHUNK_SIZE} \
${cmd_opts}
${cmd_opts} \
2>&1 | tee ${LOG_DIR}/${LOG_FILE}
fi
32 changes: 19 additions & 13 deletions patrickstar/core/client.py
@@ -31,7 +31,7 @@
import torch

import patrickstar.utils.global_timer as global_timer
from patrickstar.utils import logger, get_world_size, get_rank
from patrickstar.utils import logger, get_world_size, get_rank, log_dist
from .chunk_list import ChunkList, ChunkType
from .chunk_tensor_index import ChunkTensorIndex
from .const import AccessType, ChunkState, TensorState, TrainingStage
@@ -91,7 +91,7 @@ def __init__(self, rank: int, default_chunk_size: int, config=None):
self.opt_config["with_async_move"],
)
if self.opt_config["with_mem_cache"]:
print("[CONFIG] USING MEM CACHE")
logger.debug("[CONFIG] USING MEM CACHE")
self._time_profile = True

if torch.distributed.is_initialized():
@@ -229,7 +229,7 @@ def append_dummy_chunk(self, data_type: torch.dtype, chunk_type: ChunkType):
AccessType.DATA,
)

logger.info(
logger.debug(
f"Append a dummy chunk to the Chunk List {chunk_type} "
f"comm info {comm_info}"
)
@@ -464,7 +464,7 @@ def _fetch_remote_chunks(
"CLIENT_fetch_remote_chunks_allgather"
)

logger.info(f"rank {rank} allgather {chunk_id_list}")
logger.debug(f"rank {rank} allgather {chunk_id_list}")
torch.distributed.all_gather(
allgather_payload_buff,
self.chunk_list[local_chunk_id].payload,
@@ -924,7 +924,7 @@ def get_overall_chunk_size(self):
type_chunk_list,
) in self.chunk_tensor_index.chunk_type_to_chunk_id_list_map.items():

logger.info(f"Chunk list {type}")
logger.debug(f"Chunk list {type}")
for chunk_id in type_chunk_list:
chunk = self.chunk_list[chunk_id]
if self.opt_config["with_mem_saving_comm"] and chunk.is_dummy():
@@ -943,22 +943,25 @@ def get_overall_chunk_size(self):
return overall_size, overall_utilization_ratio

def display_chunk_info(self):
logger.info("Print chunk list info.")
logger.debug("Print chunk list info.")

overall_size = 0
overall_chunk_num = 0
overall_utilization_ratio = 0.0
max_utilization_ratio = 0
for (
type,
type_chunk_list,
) in self.chunk_tensor_index.chunk_type_to_chunk_id_list_map.items():
logger.info(f"Chunk list {type}")
logger.debug(f"Chunk list {type}")
for chunk_id in type_chunk_list:
chunk = self.chunk_list[chunk_id]
if self.opt_config["with_mem_saving_comm"] and chunk.is_dummy():
continue
comm_info = self.chunk_tensor_index.chunk_id_to_comm_info_map[chunk_id]
assert comm_info is not None

logger.info(
logger.debug(
f"Chunk id {chunk.chunk_id}, state {chunk.get_state()}, "
f"comm info {comm_info}, "
f"capacity {chunk.capacity / 1024 / 1024} M elems, "
Expand All @@ -975,15 +978,18 @@ def display_chunk_info(self):
f"tensor_id {info.tensor_id}, state {info.state()}, name {info.tensor_name}"
)
last_used_pos = max(last_used_pos, info.start_offset + info.numel)
logger.info(
logger.debug(
f"chunk used {last_used_pos/1024/1024} M elem, "
f"{last_used_pos/chunk.capacity * 100} %"
)
overall_utilization_ratio += last_used_pos / chunk.capacity
cur_util = last_used_pos / chunk.capacity
max_utilization_ratio = max(cur_util, max_utilization_ratio)
overall_utilization_ratio += cur_util
overall_size += chunk.get_chunk_space()
overall_chunk_num += 1

logger.info(f"OVERALL CHUNK SIZE {overall_size / 1024 / 1024 / 1024} GB")
logger.info(
f"OVERALL UTILIZATION {overall_utilization_ratio / overall_chunk_num} %"
log_dist(f"OVERALL CHUNK SIZE {overall_size / 1024 / 1024 / 1024} GB")
log_dist(
f"OVERALL UTILIZATION {overall_utilization_ratio / overall_chunk_num * 100} %"
)
log_dist(f"MAX UTILIZATION {max_utilization_ratio * 100} %")
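display_chunk_info derives per-chunk utilization from the last occupied offset, last_used_pos / chunk.capacity, then averages that ratio over the chunks it visits and also tracks the maximum. The following is a self-contained sketch of that bookkeeping with simplified stand-in types; TensorInfo and Chunk here are not the real PatrickStar classes, just enough structure to show the arithmetic.

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class TensorInfo:
    start_offset: int
    numel: int


@dataclass
class Chunk:
    capacity: int              # number of elements the chunk can hold
    tensors: List[TensorInfo]


def utilization(chunk: Chunk) -> float:
    # Fraction of the chunk covered up to the last used element.
    last_used_pos = max((t.start_offset + t.numel for t in chunk.tensors), default=0)
    return last_used_pos / chunk.capacity


def summarize(chunks: List[Chunk]) -> Tuple[float, float]:
    ratios = [utilization(c) for c in chunks]
    overall = sum(ratios) / len(ratios) if ratios else 0.0
    return overall * 100, max(ratios, default=0.0) * 100  # percentages


chunks = [
    Chunk(64 * 1024 * 1024, [TensorInfo(0, 48 * 1024 * 1024)]),
    Chunk(64 * 1024 * 1024, [TensorInfo(0, 20 * 1024 * 1024)]),
]
overall_pct, max_pct = summarize(chunks)
print(f"OVERALL UTILIZATION {overall_pct:.1f} %  MAX UTILIZATION {max_pct:.1f} %")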
34 changes: 21 additions & 13 deletions patrickstar/core/memtracer/memtracer.py
@@ -34,6 +34,7 @@
from patrickstar.core.const import TrainingStage
from patrickstar.profiler import profiler
from patrickstar.utils import (
log_dist,
get_memory_info,
get_sys_memory_used,
get_world_size,
@@ -125,10 +126,6 @@ def __init__(self, local_rank: int = 0, config=None):
if self.use_async_mem_monitor:
self.async_mem_monitor = AsyncMemoryMonitor()

print(
f"[Mem Tracer] Using Asyn Mem Monitor Flag : {self.use_async_mem_monitor}"
)

mem_info = get_memory_info()
local_world_size = get_local_world_size()
if self.use_fake_dist:
@@ -150,7 +147,7 @@ def __init__(self, local_rank: int = 0, config=None):
mem_info.total * self._overall_cpu_mem_ratio / local_world_size
)

logger.info(
log_dist(
f"Init Manager over all gpu mem {self._overall_gpu_mem / 1e6} MB, "
f"cpu mem {self._overall_cpu_mem / 1e6} MB"
)
@@ -175,14 +172,14 @@ def close_tracer(self):
"""
if self.use_async_mem_monitor:
self.async_mem_monitor.finish()
print("**** Memory Tracer is closed! ****")
log_dist("**** Memory Tracer is closed! ****")

def start_train(self, param_fp16_chunk_size, chunk_size):
self._param_fp16_chunk_size = param_fp16_chunk_size
self._default_chunk_size = chunk_size
if self.use_async_mem_monitor:
self.async_mem_monitor.start()
print("**** Memory Tracer is stared! ****")
log_dist("**** Memory Tracer is stared! ****")

def update_margin_mem(self):
r"""Update the number of GPU free chunks for optimizer."""
@@ -193,6 +190,15 @@ def update_margin_mem(self):
max_gpu_sys_used = 0
else:
max_gpu_sys_used = max(self.gpu_sys_used_list)

if len(self.cpu_sys_used_list) == 0:
logger.warning(
"No cpu info collected. Maybe there are no chunk based tensors."
)
max_cpu_sys_used = 0
else:
max_cpu_sys_used = max(self.cpu_sys_used_list)

margin_mem_size = (
self._overall_gpu_mem - max_gpu_sys_used - self._param_fp16_chunk_size
)
@@ -201,14 +207,16 @@ def update_margin_mem(self):
(margin_mem_size) / (self._default_chunk_size * 12) * self._margin_use_ratio
)

logger.info("--------------- GPU INFO AFTER BWD ----------------")
logger.info(f"Max GPU System Mem (non-chunk) Used {max_gpu_sys_used / 1e6} MB")
logger.info(f"Param FP16 Chunk Size {self._param_fp16_chunk_size / 1e6} MB")
logger.info(
log_dist("--------------- GPU INFO AFTER BWD ----------------")
log_dist(f"Max GPU System Mem (non-chunk) Used {max_gpu_sys_used / 1e6} MB")
log_dist(f"Max CPU System Mem (non-chunk) Used {max_cpu_sys_used / 1e6} MB")
log_dist(f"Param FP16 Chunk Size {self._param_fp16_chunk_size / 1e6} MB")
log_dist(
f"Margin Mem Size {margin_mem_size / 1e6} MB, "
f"available chunk num for Optimizer States {self._margin_chunk_num_for_gpu_adam}"
)
logger.info(f"OVERALL GPU MEM {self._overall_gpu_mem}")
log_dist("--------------- GPU INFO AFTER BWD ----------------")
logger.debug(f"OVERALL GPU MEM {self._overall_gpu_mem/1024/1024} MB")

def reset_memory_stats(self):
"""
@@ -228,7 +236,7 @@ def reset_memory_stats(self):
self.gpu_used_list = []
self.gpu_chunk_used_list = []
self.gpu_sys_used_list = []
logger.info("Reset Memory Statistics")
log_dist("Reset Memory Statistics")

def get_margin_chunk_num_for_gpu_adam(self):
return self._margin_chunk_num_for_gpu_adam
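update_margin_mem converts the GPU memory left after the measured non-chunk peak and the fp16 parameter chunks into a count of chunks usable for optimizer states. The divisor default_chunk_size * 12 reads as 12 bytes of fp32 Adam state per element (parameter copy, momentum, variance); that interpretation is an inference from the constant, not something the code states. A worked example with illustrative values:

GB = 1024 ** 3

# Illustrative numbers; only default_chunk_size mirrors a typical setting.
overall_gpu_mem = 32 * GB               # total GPU memory budget
max_gpu_sys_used = 6 * GB               # measured non-chunk peak
param_fp16_chunk_size = 10 * GB         # bytes held by param fp16 chunks
default_chunk_size = 64 * 1024 * 1024   # elements per chunk
margin_use_ratio = 0.8                  # keep some headroom

margin_mem_size = overall_gpu_mem - max_gpu_sys_used - param_fp16_chunk_size
# 12 bytes/element ~ fp32 param + momentum + variance for Adam (assumed).
margin_chunk_num_for_gpu_adam = int(
    margin_mem_size / (default_chunk_size * 12) * margin_use_ratio
)
print(
    f"Margin Mem Size {margin_mem_size / 1e6} MB, "
    f"available chunk num for Optimizer States {margin_chunk_num_for_gpu_adam}"
)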
8 changes: 3 additions & 5 deletions patrickstar/core/preprocess.py
@@ -35,7 +35,7 @@
from patrickstar.core import register_param, is_param_registered, ParamType
from patrickstar.manager import _runtime_config
from patrickstar.ops import Embedding
from patrickstar.utils import logger, print_rank, get_rank, get_world_size
from patrickstar.utils import logger, log_dist, print_rank, get_rank, get_world_size
from patrickstar.utils import see_memory_usage

_orig_torch_empty = torch.empty
@@ -251,7 +251,7 @@ def _post_context_exec(self):
number of processes.
3. Add a dummy param at the start of CPU Embedding for huggingface.
"""
logger.info("Post Model Init Context")
log_dist("Post Model Init Context")

def _origin_new(cls, *arg, **kwargs):
return object.__new__(cls)
@@ -333,7 +333,7 @@ def _origin_new(cls, *arg, **kwargs):
chunk_num += 1

world_size = get_world_size()
logger.info(f"param fp16 chunk num {chunk_num}")
log_dist(f"Param fp16 chunk num {chunk_num}")
while chunk_num % world_size != 0:
self.client.append_dummy_chunk(torch.half, ChunkType.PARAM_FP16)
chunk_num += 1
@@ -363,8 +363,6 @@ def _post_init_method(self, module):
self.client.torch_param_allreduce_list.append(param)
return

print_rank(f"Converting Params in {module.__class__.__name__}", force=False)

if not _runtime_config.use_chunk:
for name, param in module.named_parameters(recurse=False):
name = f"{module.__class__.__name__}.{name}_{self.param_idx}"
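The dummy-chunk padding in _post_context_exec keeps the number of param fp16 chunks divisible by the world size, so every rank owns the same number of chunks and collectives see uniform shapes. A minimal sketch of that padding logic; append_dummy_chunk here is a stub callback, not the real client method.

def pad_chunk_num(chunk_num: int, world_size: int, append_dummy_chunk) -> int:
    # Append empty placeholder chunks until the total divides evenly across ranks.
    while chunk_num % world_size != 0:
        append_dummy_chunk()
        chunk_num += 1
    return chunk_num


appended = []
total = pad_chunk_num(
    chunk_num=10, world_size=4, append_dummy_chunk=lambda: appended.append("dummy")
)
print(total, len(appended))  # 12 chunks in total, 2 of them dummies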
(Diffs for the remaining 7 changed files are not shown.)
