From e7fd88ecce575f42630db8568f6561160a8eb660 Mon Sep 17 00:00:00 2001 From: shuailong616 <452509829@qq.com> Date: Thu, 8 Jan 2026 17:32:04 +0800 Subject: [PATCH 1/3] support flagscale v1.0.0 version --- .../runner/auto_tuner/simulator/README.md | 45 ++ .../simulator/analylize_pipeline_time.py | 501 ++++++++++++++++++ .../runner/auto_tuner/simulator/config_gen.py | 432 +++++++++++++++ .../custom_backend/include/dummy.hpp | 157 ++++++ .../simulator/custom_backend/setup.py | 25 + .../simulator/custom_backend/src/dummy.cpp | 285 ++++++++++ flagscale/train/datasets/sft_dataset.py | 230 ++++++++ flagscale/train/megatron/train_gpt.py | 2 + .../train/megatron/training/arguments.py | 4 +- .../train/megatron/training/arguments_fs.py | 75 ++- flagscale/train/megatron/training/training.py | 23 +- 11 files changed, 1759 insertions(+), 20 deletions(-) create mode 100644 flagscale/runner/auto_tuner/simulator/README.md create mode 100644 flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py create mode 100644 flagscale/runner/auto_tuner/simulator/config_gen.py create mode 100644 flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp create mode 100644 flagscale/runner/auto_tuner/simulator/custom_backend/setup.py create mode 100644 flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp create mode 100644 flagscale/train/datasets/sft_dataset.py diff --git a/flagscale/runner/auto_tuner/simulator/README.md b/flagscale/runner/auto_tuner/simulator/README.md new file mode 100644 index 0000000000..4ac43d38ce --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/README.md @@ -0,0 +1,45 @@ +# Environment +Begin at the root path of `FlagScale` repository: +1. Install backend +``` +cd flagscale/runner/auto_tuner/simulator/custom_backend/ +python setup.py develop +``` + +# Setup +2. Set necessary parameters in `config_gen.py`. For example: +``` +device_type_list = ["A", "B"] +device_num_list = [4, 4] +global_batch_size = 32 +num_micro_batches = 8 +num_layers = 4 +``` +# Run a Task +3. Start the auto-tuning: + a. set PYTHONPATH + ``` + export PYTHONPATH=/***/FlagScale:$PYTHONPATH + export PYTHONPATH=$PYTHONPATH:/***/FlagScale/third_party/Megatron-LM + + vim /***/FlagScale/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py + os.environ["PYTHONPATH"] = ( + "/***/FlagScale:" + "/***/FlagScale/third_party/Megatron-LM" + ) + ``` + b. run + + vim flagscale/runner/auto_tuner/simulator/config_gen.py + + set scheme = vpp or scheme = 1F1B + + python flagscale/runner/auto_tuner/simulator/config_gen.py + + c. 
result + ``` + {'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 8, 5, 5, 5, 1], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 57.52105478485333, 'theory_peak_memory': [110.487650304, 118.80914944, 158.35625472, 158.35625472, 158.35625472, 42.519842816], 'oom_error': True} + {'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 7, 5, 5, 5, 2], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 61.20105478485332, 'theory_peak_memory': [110.487650304, 109.345202176, 158.35625472, 158.35625472, 158.35625472, 61.447737344], 'oom_error': True} + {'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 8, 5, 5, 4, 2], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 54.73105478485331, 'theory_peak_memory': [110.487650304, 118.80914944, 158.35625472, 158.35625472, 119.365943296, 61.447737344], 'oom_error': True} +... +``` diff --git a/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py new file mode 100644 index 0000000000..efe28be9cc --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py @@ -0,0 +1,501 @@ +import os +import re +import subprocess +import time +from collections import defaultdict +# from megatron.training import get_args + + +def kill_other_python_processes(): + current_pid = os.getpid() + clear_cmd = f"pkill -f python -o --signal TERM --ignore \"${current_pid}\"" + subprocess.run(clear_cmd, text=True, shell=True) + + +def compute_pipeline_parallelism_cost( + scheme: str = '1F1B', + # num_stages: int=1, + num_micro_batches: int = 1, + process_mesh: list = None, + pp_layers_split: list = None, + fwd_time_per_stage_chunk: list = None, + bwd_time_per_stage_chunk: list = None, + comm_time_between_stages: list = None, + vpp_partition: list = None, + # TODO: add fine-greaied recomputation +): + print(f"--- Compute Pipeline Cost ---") + + # process_mesh: [tp0,cp0,ep0,dp0,pp0,(tp1,cp1,...)] + # comm_time_between_stages[i] means the comm time between stage i-1 and stage i + num_pp_stages = sum(process_mesh[4::5]) + + assert ( + len(pp_layers_split) == num_pp_stages + ), "\flength of list {num_layers_per_stage} should match {num_stages}" + if scheme == 'vpp': + num_pp_stages = sum(vpp_partition) + + assert ( + len(fwd_time_per_stage_chunk) == num_pp_stages + ), "\flength of list {fwd_time_per_stage_chunk} should match {num_stages}" + assert ( + len(bwd_time_per_stage_chunk) == num_pp_stages + ), "\flength of list {bwd_time_per_stage_chunk} should match {num_stages}" + assert ( + len(comm_time_between_stages) == num_pp_stages + ), "\flength of list {comm_time_between_stages} should match {num_stages}" + + pp_last_stage_time = num_micro_batches * ( + fwd_time_per_stage_chunk[num_pp_stages - 1] + bwd_time_per_stage_chunk[num_pp_stages - 1] + ) + if num_pp_stages == 1: + return num_micro_batches * ( + fwd_time_per_stage_chunk[num_pp_stages - 1] + + bwd_time_per_stage_chunk[num_pp_stages - 1] + ) + + pipeline_cost = 0 + # TODO: consider when comm time > comp time + # each stage onlt depends on its next stage + if scheme == '1F1B' or scheme == 'AFAB': + pipeline_cost = pp_last_stage_time + for stage_from_last in range(2, num_pp_stages): + pp_this_stage_overlapped_time = (num_micro_batches 
- 1) * ( + fwd_time_per_stage_chunk[num_pp_stages - 1] + + bwd_time_per_stage_chunk[num_pp_stages - 1] + ) + pp_this_stage_compute_time = ( + fwd_time_per_stage_chunk[num_pp_stages - stage_from_last] + + bwd_time_per_stage_chunk[num_pp_stages - stage_from_last] + ) + pp_last_stage_overall_time = ( + pipeline_cost + 2 * comm_time_between_stages[num_pp_stages - stage_from_last + 1] + ) + # not consider the situation that comm stucks the comp + # which means the comm time should no more than the comp time(fwd time) + pipeline_cost = pp_this_stage_compute_time + max( + pp_last_stage_overall_time, pp_this_stage_overlapped_time + ) + # else: + # raise (ValueError("Scheme must be '1F1B' or 'AFAB'.")) + elif scheme == 'vpp': + num_vp_stages = len(fwd_time_per_stage_chunk) + num_pp_stages = len(comm_time_between_stages) # error + vstage_to_pp = [] + for i, count in enumerate(vpp_partition): + vstage_to_pp += [i] * count + + comm_per_vstage = [0.0] * num_vp_stages + for i in range(num_vp_stages - 1): + cur_pp, next_pp = vstage_to_pp[i], vstage_to_pp[i + 1] + if next_pp != cur_pp: + comm_per_vstage[i] = comm_time_between_stages[cur_pp + 1] + + vp_last_stage_time = num_micro_batches * ( + fwd_time_per_stage_chunk[-1] + bwd_time_per_stage_chunk[-1] + ) + pipeline_cost = vp_last_stage_time + for vp_from_last in range(2, num_vp_stages + 1): + this_vp_idx = num_vp_stages - vp_from_last + this_stage_fwd = fwd_time_per_stage_chunk[this_vp_idx] + this_stage_bwd = bwd_time_per_stage_chunk[this_vp_idx] + this_stage_compute_time = this_stage_fwd + this_stage_bwd + + pp_idx = this_vp_idx % num_pp_stages + comm_time = comm_time_between_stages[min(pp_idx + 1, num_pp_stages - 1)] + + this_vp_overlapped_time = (num_micro_batches - 1) * this_stage_compute_time + + last_vp_total_time = pipeline_cost + 2 * comm_time + + pipeline_cost = this_stage_compute_time + max( + this_vp_overlapped_time, last_vp_total_time + ) + + return pipeline_cost + + +import random + +LAYER_RE = re.compile(r"decoder\.layers\.(\d+)\.(.+)") + +def extract_stage_ops_from_raw_log(log_text: str): + layers = defaultdict(lambda: { + "has_attention": False, + "has_mlp": False, + "has_qkv": False, + "has_proj": False, + "has_fc1": False, + "has_fc2": False, + }) + + for raw_line in log_text.splitlines(): + line = raw_line.strip() + + if "decoder.layers." 
not in line: + continue + + m = LAYER_RE.search(line) + if not m: + continue + + layer_id = int(m.group(1)) + suffix = m.group(2) + + # attention + if "self_attention.linear_qkv" in suffix: + layers[layer_id]["has_attention"] = True + layers[layer_id]["has_qkv"] = True + + if "self_attention.linear_proj" in suffix: + layers[layer_id]["has_attention"] = True + layers[layer_id]["has_proj"] = True + + # mlp + if "mlp.linear_fc1" in suffix: + layers[layer_id]["has_mlp"] = True + layers[layer_id]["has_fc1"] = True + + if "mlp.linear_fc2" in suffix: + layers[layer_id]["has_mlp"] = True + layers[layer_id]["has_fc2"] = True + + return layers + +def tp_collectives_per_stage(layers, sequence_parallel=False): + total = 0 + per_layer = {} + + for layer_id, ops in layers.items(): + cnt = tp_collectives_per_layer( + has_attention=ops["has_attention"], + has_mlp=ops["has_mlp"], + sequence_parallel=sequence_parallel, + ) + per_layer[layer_id] = cnt + total += cnt + + return total, per_layer + +def tp_collectives_per_layer( + has_attention=True, + has_mlp=True, + sequence_parallel=False +): + cnt = 0 + if has_attention: + cnt += 1 # qkv backward + cnt += 2 # proj fwd + bwd + if has_mlp: + cnt += 1 # fc1 backward + cnt += 2 # fc2 fwd + bwd + if sequence_parallel: + cnt += 4 # ln fwd/bwd rs + ag + return cnt + +def ring_allreduce_time( + n_bytes, + N_ranks, + N_nodes, + alpha_base, + alpha_intra, + alpha_inter, + hops, + alpha_switch, + beta, +): + alpha_hw = ( + 2 * (N_ranks - N_nodes) * alpha_intra + + 2 * (N_nodes - 1) * (alpha_inter * hops * alpha_switch) + ) + + bw_term = 2 * (N_ranks - 1) / N_ranks * n_bytes * beta + + return alpha_base + alpha_hw + bw_term + +def stage_has_tp_from_process_mesh(process_mesh): + assert len(process_mesh) % 5 == 0 + + stage_has_tp = {} + stage_id = 0 + + for i in range(0, len(process_mesh), 5): + device = process_mesh[i:i+5] + tp = device[0] + pp = device[4] + + has_tp = tp > 1 + + for _ in range(pp): + stage_has_tp[stage_id] = has_tp + stage_id += 1 + + return stage_has_tp + + +def simulator( + process_mesh: list = None, + stage: int = 0, + num_layers: int = None, + simulated_rank: int = None, + pp_layers_split: list = None, +): + + # os.environ["PYTHONPATH"] = "/share/project/heyongzhe/FlagScale/megatron:/share/project/heyongzhe/FlagScale" + #os.environ["PYTHONPATH"] = ( + # "/workspace/single_process_simulator_nd/FlagScale:" + # "/workspace/single_process_simulator_nd/FlagScale/third_party/Megatron-LM" + #) + os.environ["ENABLE_SIMULATOR"] = "1" + os.environ["CUDA_VISIBLE_DEVICES"] = "3" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + os.environ["RANK"] = str(simulated_rank) + os.environ["LOCAL_RANK"] = "0" + # os.environ["WORLD_SIZE"] = args.world_size + os.environ["WORLD_SIZE"] = "8" + # os.environ["WORLD_SIZE"] = "32" + rdav_endpoint = random.randint(0, 40000) + os.environ["RDZV_ENDPOINT"] = "localhost:" + str(rdav_endpoint) + # os.environ["RZDV_ENDPOINT"]="localhost:37832" + os.environ["RDZV_BACKEND"] = "c10d" + os.environ["MASTER_ADDR"] = "localhost" + + program_entry = " ./flagscale/train/megatron/train_aquila_sft.py " + simulation_arguments = " --enable-hetero --enable-simulator --distributed-backend dummy " + # fine_grained_recomputation_args = "--recompute-granularity-per-stage-micro-batch '[1, 1, 1]' --recompute-method-per-stage-micro-batch '[1, 1, 1]' --recompute-num-layers-per-stage-micro-batch '[1, 1, 1]'" + fine_grained_recomputation_args = "" + # print(stage) + + pp_layer_split_args = " --hetero-pipeline-layer-split " + for layers in pp_layers_split: 
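+        # e.g. pp_layers_split = [8, 8, 8, 8] is rendered as "--hetero-pipeline-layer-split 8 8 8 8 "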
+ pp_layer_split_args = pp_layer_split_args + str(layers) + " " + + process_mesh_str = " --hetero-process-meshes " + for dim in process_mesh: + process_mesh_str = process_mesh_str + str(dim) + " " + + num_pp_stages = sum(process_mesh[4::5]) + pp_size_args = " --pipeline-model-parallel-size " + str(num_pp_stages) + " " + + # TODO: too ugly to show this command in the code, re-organize these parameters in another way later + train_command = ( + "python " + + program_entry + + "--tensor-model-parallel-size 1 --timing-log-level 2 --disable-bias-linear --use-flash-attn --sequence-parallel --use-distributed-optimizer --use-mcore-models --transformer-impl transformer_engine --hetero-device-types A800 A800 --hetero-current-device-type A800 --bf16 --attention-softmax-in-fp32 --accumulate-allreduce-grads-in-fp32 --log-interval 1 --log-throughput --tensorboard-log-interval 1 --wandb-project aquila2 --wandb-exp-name test --tensorboard-dir /share/project/heyongzhe/FlagScale/outputs/tensorboard --wandb-save-dir /share/project/heyongzhe/FlagScale/outputs/wandb --num-layers 32 --hidden-size 4096 --num-attention-heads 32 --seq-length 2048 --max-position-embeddings 2048 --norm-epsilon 1e-05 --use-rotary-position-embeddings --no-position-embedding --swiglu --multiple-of 256 --normalization RMSNorm --untie-embeddings-and-output-weights --init-method-std 0.0165 --attention-dropout 0.0 --hidden-dropout 0.0 --weight-decay 0.1 --clip-grad 1.0 --train-samples 128 --global-batch-size 64 --micro-batch-size 1 --seed 42 --lr 0.0002 --weight-decay 0.01 --adam-beta1 0.9 --adam-beta2 0.95 --lr 0.00015 --min-lr 1.5e-05 --lr-warmup-samples 0 --lr-decay-style cosine --data-path /workspace/FlagScale/datapath/pile_wikipedia_demo --split 1 --tokenizer-type AquilaTokenizerFS --vocab-file ./examples/aquila/tokenizer/vocab.json --merge-file ./examples/aquila/tokenizer/merges.txt --special-tokens-file ./examples/aquila/tokenizer/special_tokens.txt --vocab-size 100008 " + + process_mesh_str + + simulation_arguments + + pp_layer_split_args + + fine_grained_recomputation_args + + pp_size_args + ) + + # enough sleeping time is needed to really kill the survival megatron process + # as least 5 sec before & after killing can not succeed every time + print("sleeping...") + # print(train_command) + # time.sleep(10) + kill_other_python_processes() + # time.sleep(10) + print("start...") + result = subprocess.run(train_command, capture_output=True, text=True, shell=True) + print(result) + output = result.stdout.strip() + print(train_command) + + print("------------------------------output1--------------------------------------") + print(output) + print("--------------------------------------------------------------------") + # example output: "[simulatior output] forward: 12.34, backward: 56.78, communication: 90.12" + match = re.search(r"forward:\s*([\d.]+),\s*backward:\s*([\d.]+)", output) + s_out = extract_stage_ops_from_raw_log(output) + reduce_op_cnt = tp_collectives_per_stage(s_out)[0] + + + n_bytes = 16.7 + n_rank = 2 + n_nodes = 1 + alpha_base = 5e-6 + alpha_intra = 1e-6 + alpha_inter = 3e-6 + hops = 1 if n_nodes > 1 else 0 + alpha_switch = 1.2 + beta = 1 / (25 * 1024**3) + + time = ring_allreduce_time(n_bytes, n_rank, n_nodes, alpha_base, alpha_intra, alpha_inter, hops, alpha_switch, beta) + + fw_cm_time = time * reduce_op_cnt * 0.33 + bw_cm_time = time * reduce_op_cnt * 0.66 + stp = stage_has_tp_from_process_mesh(process_mesh) + + + if match: + fwd_time = float(match.group(1)) + bwd_time = float(match.group(2)) + # comm_time = 
float(match.group(3)) + comm_time = estimate_comm_time_between_stages(1, 2048, 4096) + if stp[stage]: + fwd_time += fw_cm_time + bwd_time += bw_cm_time + print("forward:", fwd_time) + print("backward:", bwd_time) + print("communication:", comm_time) + else: + raise ( + ValueError( + "Results not found. Example output: \"[simulatior output] forward: 12.34, backward: 56.78, communication: 90.12\"" + ) + ) + return fwd_time, bwd_time, comm_time + + +def compute_vpp_from_layers( + pp_layers_split, target_layers_per_vstage=2, device_speed=None, min_layers_per_virtual_stage=2 +): + """ + Args: + pp_layers_split: list[int] + target_layers_per_vstage: int + device_speed: list[float] + min_layers_per_virtual_stage: + Returns: + vpp_list: list[int], + """ + vpp_list = [] + max_speed = max(device_speed) if device_speed else 1.0 + + for i, num_layers in enumerate(pp_layers_split): + base_vpp = max(1, round(num_layers / target_layers_per_vstage)) + + if device_speed: + scale = device_speed[i] / max_speed + base_vpp = max(1, round(base_vpp * scale)) + + base_vpp = min(base_vpp, num_layers // min_layers_per_virtual_stage) + if base_vpp == 0: + base_vpp = 1 + + while num_layers % base_vpp != 0 and base_vpp > 1: + base_vpp -= 1 + + vpp_list.append(base_vpp) + + return vpp_list + + +def estimate_comm_time_between_stages( + batch_size: int, + seq_len: int, + hidden_size: int, + dtype_bytes: int = 2, # bf16 + bandwidth_GBps: float = 300.0, + latency_ms: float = 0.01, + tensor_model_parallel_size: int = 1, + virtual_pipeline_size: int = 1, + activation_fraction: float = 1.0, + use_allgather_for_activation: bool = False, +): + bytes_one_way = batch_size * seq_len * hidden_size * dtype_bytes * activation_fraction + if tensor_model_parallel_size > 1: + bytes_one_way /= tensor_model_parallel_size + + K = max(1, virtual_pipeline_size) + per_transfer_bytes = bytes_one_way / K + bw_Bps = bandwidth_GBps * 1e9 + one_way_time = per_transfer_bytes / bw_Bps + latency_ms / 1000.0 + comm_time = 2 * K * one_way_time # fwd+bwd + + if use_allgather_for_activation and tensor_model_parallel_size > 1: + extra_bytes = ( + (tensor_model_parallel_size - 1) / tensor_model_parallel_size + ) * bytes_one_way + comm_time += extra_bytes / bw_Bps + (latency_ms / 1000.0) + return comm_time + + +# call simulator to obtain the execution of each stage +def simulate_pipeline_parallelism_per_stage_time( + process_mesh: list = None, + pp_layers_split: list = None, + scheme: str = '1F1B', + fwd_time_per_stage_chunk: list = None, + bwd_time_per_stage_chunk: list = None, + comm_time_between_stages: list = None, +): + print(f"--- Simulation Begin ---") + print(f"Process Mesh: {process_mesh}") + print(f"PP Layer Split: {pp_layers_split}") + if scheme == '1F1B': + for stage, num_layers in enumerate(pp_layers_split): + # TODO: confirm simulated_rank for different stage + print(f"Stage: {stage}; Num Layers: {num_layers}") + simulated_rank = stage + try: + fwd_time, bwd_time, comm_time = simulator( + process_mesh, stage, num_layers, simulated_rank, pp_layers_split + ) + fwd_time_per_stage_chunk.append(fwd_time) + bwd_time_per_stage_chunk.append(bwd_time) + comm_time_between_stages.append(comm_time) + except Exception as e: + print(f"[Error] Simulator failed at stage {stage}, skip. 
Reason: {e}") + continue + + elif scheme == 'vpp': + vpp_list = compute_vpp_from_layers(pp_layers_split) + print(vpp_list) + for stage_idx, (num_layers, vpp) in enumerate(zip(pp_layers_split, vpp_list)): + layers_per_chunk = num_layers // vpp + for vstage_idx in range(vpp): + vstage_name = f"{stage_idx}-{vstage_idx}" + print(f" ->Stage {vstage_name} : ( {layers_per_chunk})") + try: + fwd_time, bwd_time, comm_time = simulator( + process_mesh=process_mesh, + stage=vstage_name, + num_layers=layers_per_chunk, + simulated_rank=stage_idx, + pp_layers_split=pp_layers_split, + ) + fwd_time_per_stage_chunk.append(fwd_time) + bwd_time_per_stage_chunk.append(bwd_time) + comm_time_between_stages.append(comm_time) + except Exception as e: + print(f"[Error] Simulator failed at V-stage {vstage_name}, skip. Reason: {e}") + continue + + print(f"--- Simulation End ---") + + +def analyze_pp_time( + scheme: str = '1F1B', + num_micro_batches: int = 1, + process_mesh: list = None, + pp_layers_split: list = None, +): + fwd_time_per_stage_chunk = [] + bwd_time_per_stage_chunk = [] + comm_time_between_stages = [] + vpp_partition = compute_vpp_from_layers(pp_layers_split) + + simulate_pipeline_parallelism_per_stage_time( + process_mesh=process_mesh, + pp_layers_split=pp_layers_split, + scheme=scheme, + fwd_time_per_stage_chunk=fwd_time_per_stage_chunk, + bwd_time_per_stage_chunk=bwd_time_per_stage_chunk, + comm_time_between_stages=comm_time_between_stages, + ) + + pipeline_cost = compute_pipeline_parallelism_cost( + scheme=scheme, + num_micro_batches=num_micro_batches, + process_mesh=process_mesh, + pp_layers_split=pp_layers_split, + fwd_time_per_stage_chunk=fwd_time_per_stage_chunk, + bwd_time_per_stage_chunk=bwd_time_per_stage_chunk, + comm_time_between_stages=comm_time_between_stages, + vpp_partition=vpp_partition, + ) + + return pipeline_cost diff --git a/flagscale/runner/auto_tuner/simulator/config_gen.py b/flagscale/runner/auto_tuner/simulator/config_gen.py new file mode 100644 index 0000000000..c5b81c2047 --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/config_gen.py @@ -0,0 +1,432 @@ +import ast +import json +import os + +from functools import reduce +from itertools import combinations, product + +import analylize_pipeline_time + +# from itertools import product +#import flagscale.train.theoretical_memory_usage as mem_usg +import flagscale.train.megatron.training.theoretical_memory_usage as mem_usg + +BYTES_OF_GB = 10**9 + + +device_type_list = ["A800", "A800"] +device_num_list = [4, 4] +memory_capacity_of_devices = [80, 80] # GB + +global_batch_size = 512 +num_micro_batches = 8 +num_layers = 32 + +num_gpus = sum(device_num_list) + + +class DevicesInfo: + def __init__(self, device_type_list: list, device_num_list: list): + assert len(device_type_list) == len( + device_num_list + ), "\flength of list {device_type_list} should match {device_num_list}" + self.device_type_list = device_type_list + self.device_num_list = device_num_list + self.device_types_count = len(device_type_list) + self.possible_parallelisms = [] + + +class HeteroConfig: + def __init__( + self, + mesh, + device_types, + pp_layer_split, + recompute_granularity=None, + recompute_method="uniform", + recompute_num_layers=1, + theory_peak_memory=0.0, + oom_error=False, + ): + self.mesh = mesh + self.device_types = device_types + self.pp_layer_split = pp_layer_split + # self.micro_batch_size = 1 + self.recompute_granularity = recompute_granularity + self.recompute_method = recompute_method + self.recompute_num_layers = 
recompute_num_layers + + self.simulated_time = 0.0 + self.theory_peak_memory = theory_peak_memory + self.oom_error = oom_error + + +def generate_hetero_meshes( + devices_info: DevicesInfo, + global_batch_size: int = None, + num_layers: int = None, + output_file: str = "results.json", +): + def enumerate_parallelism(device_num: int = None): + possible_parallelisms = [] + for tp in range(1, device_num + 1): + for dp in range(1, device_num // tp + 1): + if device_num % (dp * tp) == 0: + pp = device_num // (dp * tp) + # mesh: [tp, cp, ep, dp, pp] + possible_parallelisms.append([tp, 1, 1, dp, pp]) + return possible_parallelisms + + def is_legal_combination(comb: list): + pp = sum(comb[4::5]) + # check dp is legal + max_dp = global_batch_size // pp + for dp in comb[3::5]: + if max_dp % dp != 0: + return False + for i in range(len(comb) // 5): + tp, _, _, dp, pp = comb[i * 5 : i * 5 + 5] + device_num = devices_info.device_num_list[i] + if tp * dp * pp != device_num: + return False + return True + + def is_extreme_strategy(comb: list): + for mesh_index in range(len(comb) // 5): + num_devices_in_mesh = reduce( + lambda x, y: x * y, comb[mesh_index * 5 : mesh_index * 5 + 5] + ) + dp_size_in_mesh = comb[mesh_index * 5 + 3] + tp_size_in_mesh = comb[mesh_index * 5 + 0] + pp_size_in_mesh = comb[mesh_index * 5 + 4] + print( + mesh_index, + comb[mesh_index * 5 : mesh_index * 5 + 5], + num_devices_in_mesh, + dp_size_in_mesh, + tp_size_in_mesh, + pp_size_in_mesh, + ) + if ( + pp_size_in_mesh > num_devices_in_mesh // 2 + or tp_size_in_mesh > 8 + or dp_size_in_mesh > num_devices_in_mesh / 4 + ): + return True + else: + return False + + def combine_possible_parallelisms(possible_parallelisms, output_file): + '''Combine and filter results, writing them to a file to avoid OOM.''' + all_combinations = product(*possible_parallelisms) + with open(output_file, "w") as f: + for comb in all_combinations: + result = sum(comb, []) + if is_legal_combination(result): + if not is_extreme_strategy(result): + f.write(",".join(map(str, result)) + "\n") + + # Ensure output file does not exist initially + if os.path.exists(output_file): + os.remove(output_file) + + # Enumerate all possible meshes for each kind of device + for i in range(devices_info.device_types_count): + device_num = devices_info.device_num_list[i] + devices_info.possible_parallelisms.append(enumerate_parallelism(device_num)) + + # Combine possibilities and write results to file + combine_possible_parallelisms(devices_info.possible_parallelisms, output_file) + print(f"Results written to {output_file}") + + +def extract_mesh_stage_structure(mesh): + stage_counts = [] + for i in range(0, len(mesh), 5): + stage_counts.append(mesh[i + 4]) + return stage_counts + + +def split_layers(num_layers, pp_stages, mesh): + results = [] + mesh_stage_counts = extract_mesh_stage_structure(mesh) + # print(pp_stages) + for split_points in combinations(range(1, num_layers), pp_stages - 1): + # print(split_points) + if len(split_points) == 0: + continue + splits = ( + [split_points[0]] + + [split_points[i] - split_points[i - 1] for i in range(1, len(split_points))] + + [num_layers - split_points[-1]] + ) + # to prune some extreme splits + stage_index = 0 + mesh_total_layers = [] + violate = False + + for m in mesh_stage_counts: + sub_splits = splits[stage_index : stage_index + m] + stage_index += m + + if not sub_splits: + continue + + if max(sub_splits) - min(sub_splits) > 4: + violate = True + break + + mesh_total_layers.append(sum(sub_splits)) + + if violate: + continue + + if 
max(mesh_total_layers) - min(mesh_total_layers) > 4: + continue + + results.append(splits) + + return results + + +class MeshArguments: + def __init__(self, mesh_config: HeteroConfig): + # [tp, cp, ep, dp, pp] + self.data_parallel_size = mesh_config.mesh[3] + # TODO: pp size not correct when computing memory, because former method divides the layers evenly + # no embed and dropout for stages except the 1st, and make the layers changable + + # if args.pipeline_model_parallel_size > 1: + # activation_memory = ( + # perlayer_activation + # * args.num_layers + # / args.pipeline_model_parallel_size + # * in_flight_microbatches + # + embedding_activation_memory + # + dropout_activation_memory + # ) + # else: + # activation_memory = ( + # perlayer_activation * args.num_layers + # + embedding_activation_memory + # + dropout_activation_memory + # + output_layer_and_loss_activation_memory + # ) + self.pipeline_model_parallel_size = sum(mesh_config.mesh[4::5]) + self.tensor_model_parallel_size = mesh_config.mesh[0] + self.virtual_pipeline_model_parallel_size = None + self.num_experts = None + + self.context_parallel_size = 1 + self.swiglu = True + self.micro_batch_size = global_batch_size / num_micro_batches / self.data_parallel_size + self.num_layers = num_layers + self.num_attention_heads = 32 + self.group_query_attention = None # TODO + self.num_query_groups = 1 # TODO + # self.moe_layer_freq = 2 + # self.moe_router_topk = 1 + self.multi_latent_attention = False + self.seq_length = 2048 + self.padded_vocab_size = 4096 # TODO + self.hidden_size = 4096 + self.qk_layernorm = False + self.mtp_num_layers = None + self.expert_model_parallel_size = 1 + self.world_size = 8 + self.moe_shared_expert_intermediate_size = None + self.moe_ffn_hidden_size = None + ## self.ffn_hidden_size + self.multiple_of = 256 + hidden_dim = int(4 * self.hidden_size * 2 / 3) + self.ffn_hidden_size = self.multiple_of * ( + (hidden_dim + self.multiple_of - 1) // self.multiple_of + ) + # self.kv_channels + self.kv_channels = self.hidden_size // self.num_attention_heads + + self.recompute_granularity = mesh_config.recompute_granularity + self.recompute_method = mesh_config.recompute_method + self.recompute_num_layers = mesh_config.recompute_num_layers + + self.expert_tensor_parallel_size = 1 + + self.use_flash_attn = True + self.sequence_parallel = True + self.use_distributed_optimizer = True + self.untie_embeddings_and_output_weights = False # TODO + + self.enable_hetero = True + + +def report_oom_error( + memory_capacity_of_devices: list, meshes_config: list, peak_memory_usage_per_stage: list +): + stage_index = 0 + for mesh_index, num_stages_in_current_mesh in enumerate(meshes_config[4::5]): + for i in range(num_stages_in_current_mesh): + if ( + peak_memory_usage_per_stage[stage_index + i] + >= memory_capacity_of_devices[mesh_index] + ): + return True + stage_index = stage_index + num_stages_in_current_mesh + return False + + +def calculate_peak_memory_per_stage(mesh_config): + peak_memory_usage_per_stage = [] + model_parallel_training_args = MeshArguments(mesh_config) + stage_index = 0 + mesh_index = 0 + for pp_stage_num_per_mesh in mesh_config.mesh[4::5]: + model_parallel_training_args.data_parallel_size = mesh_config.mesh[3 + 5 * mesh_index] + model_parallel_training_args.tensor_model_parallel_size = mesh_config.mesh[ + 0 + 5 * mesh_index + ] + for stage in range(pp_stage_num_per_mesh): + model_parallel_training_args.num_layers = mesh_config.pp_layer_split[stage_index] + + peak_activation_memory_usage = 
mem_usg.compute_activation_memory( + args=model_parallel_training_args, num_microbatches=num_micro_batches + ) + peak_weight_optimizer_usage = mem_usg.compute_weight_and_optimizer_memory( + args=model_parallel_training_args + ) + peak_memory_usage = peak_activation_memory_usage + peak_weight_optimizer_usage + + peak_memory_usage_per_stage.append(peak_memory_usage / BYTES_OF_GB) + stage_index = stage_index + 1 + + mesh_index = mesh_index + 1 + + return peak_memory_usage_per_stage + + +def gen_hetero_configs( + device_type_list, + device_num_list, + global_batch_size, + num_layers, + # num_micro_batches, + # hetero_configs: list, + output_config_file: str = "hetero_configs.json", +): + devices_info = DevicesInfo(device_type_list=device_type_list, device_num_list=device_num_list) + + generate_hetero_meshes( + devices_info=devices_info, + global_batch_size=global_batch_size, + num_layers=num_layers, + output_file="results.json", + ) + + hetero_meshes = [] + with open("results.json", "r") as f: + for line in f: + hetero_meshes.append(list(map(int, line.strip().split(",")))) + # print(hetero_meshes) + # assert False + seen = set() + with open(output_config_file, "w") as config_file: # 打开输出文件 + for mesh in hetero_meshes: + pp_stages = sum(mesh[4::5]) + # in order to prune the num of layers in each stage to even number + pp_layer_splits = split_layers(num_layers=num_layers, pp_stages=pp_stages, mesh=mesh) + for split in pp_layer_splits: + split = [x for x in split] + hetero_config = HeteroConfig( + mesh=mesh, pp_layer_split=split, device_types=device_type_list + ) + # hetero_configs.append(hetero_config) + theory_peak_memory_per_stage = calculate_peak_memory_per_stage(hetero_config) + oom_error = report_oom_error( + memory_capacity_of_devices=memory_capacity_of_devices, + meshes_config=mesh, + peak_memory_usage_per_stage=theory_peak_memory_per_stage, + ) + # if oom_error: + # continue + + key = (tuple(mesh), tuple(split), tuple(device_type_list)) + if key in seen: + continue # 跳过重复项 + seen.add(key) + + config_data = { + "mesh": hetero_config.mesh, + "device_types": hetero_config.device_types, + "pp_layer_split": hetero_config.pp_layer_split, + "recompute_granularity": hetero_config.recompute_granularity, + "recompute_method": hetero_config.recompute_method, + "recompute_num_layers": hetero_config.recompute_num_layers, + "simulated_time": hetero_config.simulated_time, + "theory_peak_memory": theory_peak_memory_per_stage, + "oom_error": oom_error, + } + config_file.write(f"{config_data}\n") + + print(f"Hetero configurations saved to {output_config_file}") + + +import ast +import json + + +def read_configs_from_json(file_path: str): + configs_list = [] + with open(file_path, "r") as file: + for line in file: + # config_data = json.loads(line.strip()) + config_data = ast.literal_eval(line.strip()) + configs_list.append(config_data) + return configs_list + + +def get_min_simulated_time_config(hetero_configs): + if not hetero_configs: + return None + return min(hetero_configs, key=lambda x: x.get("simulated_time", float("inf"))) + + +def append_config_to_file(file_path: str, config: dict): + with open(file_path, "a") as f: + f.write(str(config) + "\n") + + +# for test and usage +if __name__ == "__main__": + # hetero_configs = [] + + # generate all possible and legal mesh configs, each element of hetero_configs is a mesh list + + gen_hetero_configs( + device_type_list=device_type_list, + device_num_list=device_num_list, + global_batch_size=global_batch_size, + num_layers=num_layers, + 
output_config_file="hetero_configs.json", + # num_micro_batches=num_micro_batches, + # hetero_configs=hetero_configs + ) + + # assert False + # simulation + file_path = "hetero_configs.json" + result_path = "simulate_time.json" + hetero_configs = read_configs_from_json(file_path) + for hetero_config in hetero_configs: + print(hetero_config) + pp_cost = hetero_config['simulated_time'] = analylize_pipeline_time.analyze_pp_time( + # pp_cost = hetero_config.simulated_time = analylize_pipeline_time.analyze_pp_time( + #scheme="vpp", + scheme="1F1B", + num_micro_batches=num_micro_batches, + process_mesh=hetero_config['mesh'], + pp_layers_split=hetero_config['pp_layer_split'], + ) + print(f"pipeline cost: {pp_cost}") + append_config_to_file(result_path, hetero_config) + + best_config = get_min_simulated_time_config(hetero_configs) + print(best_config) diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp b/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp new file mode 100644 index 0000000000..a71eb8536a --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp @@ -0,0 +1,157 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include + +using AnyType = std::variant; + + +namespace c10d { + +class ProcessGroup; // 假设的类 +class Store; // 假设的类 + +class BackendDummy : public Backend { + public: + + BackendDummy(int rank, int size); + + const std::string getBackendName() const override; + void startCoalescing() override; + c10::intrusive_ptr endCoalescing() override; + +c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputTensors, + std::vector& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + +c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputTensors/* outputs */, + std::vector& inputTensors/* inputs */, + const AllgatherOptions& /* opts */ = AllgatherOptions()) override; + +c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensors/* outputBuffer */, + at::Tensor& inputTensors/* inputBuffer */, + const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) override; + +c10::intrusive_ptr broadcast( + std::vector &data, + const BroadcastOptions &opts = BroadcastOptions()) override; + +c10::intrusive_ptr allreduce( + std::vector &tensors, + const AllreduceOptions &opts = AllreduceOptions()) override; + +c10::intrusive_ptr allreduce_coalesced( + std::vector &tensors, + const AllreduceCoalescedOptions &opts = + AllreduceCoalescedOptions()) override; + +c10::intrusive_ptr reduce( + std::vector &tensors, + const ReduceOptions &opts = ReduceOptions()) override; + +c10::intrusive_ptr all_gather_object( + std::vector &outputTensors, + AnyType &inputTensors, + const AllgatherOptions &opts = AllgatherOptions()); + +c10::intrusive_ptr allgather( + std::vector> &outputTensors, + std::vector &inputTensors, + const AllgatherOptions &opts = AllgatherOptions()) override; + +c10::intrusive_ptr _allgather_base( + at::Tensor &outputBuffer, + at::Tensor &inputBuffer, + const AllgatherOptions &opts = AllgatherOptions()) override; + +c10::intrusive_ptr barrier( + const BarrierOptions &opts = BarrierOptions()) override; + +c10::intrusive_ptr gather( + std::vector> &outputTensors, + std::vector &inputTensors, + const GatherOptions &opts = GatherOptions()) override; + +c10::intrusive_ptr scatter( + std::vector &outputTensors, + std::vector> &inputTensors, + const ScatterOptions &opts = ScatterOptions()) 
override; + +c10::intrusive_ptr reduce_scatter( + std::vector &outputTensors, + std::vector> &inputTensors, + const ReduceScatterOptions &opts = ReduceScatterOptions()) override; + +c10::intrusive_ptr alltoall_base( + at::Tensor &outputTensor, + at::Tensor &inputTensor, + std::vector &outputSplitSizes, + std::vector &inputSplitSizes, + const AllToAllOptions &opts = AllToAllOptions()) override; + +c10::intrusive_ptr alltoall( + std::vector &outputTensors, + std::vector &inputTensors, + const AllToAllOptions &opts = AllToAllOptions()) override; + +c10::intrusive_ptr send( + std::vector &tensors, + int dstRank, + int tag) override; + +c10::intrusive_ptr recv( + std::vector &tensors, + int srcRank, + int tag) override; + +c10::intrusive_ptr recvAnysource( + std::vector &tensors, + int tag) override; + +static c10::intrusive_ptr createBackendDummy( + const c10::intrusive_ptr<::c10d::Store> &store, + int rank, + int size, + const std::chrono::duration &timeout); + +static void BackendDummyConstructor() __attribute__((constructor)) +{ + py::object module = py::module::import("torch.distributed"); + py::object register_backend = + module.attr("Backend").attr("register_backend"); + register_backend("dummy", py::cpp_function(createBackendDummy)); + } +}; + +class WorkDummy : public Work { + friend class BackendDummy; +public: + WorkDummy( + OpType opType, + c10::intrusive_ptr future) // future of the output + : Work( + -1, // rank, only used by recvAnySource, irrelevant in this demo + opType), + future_(std::move(future)) {} + bool isCompleted() override; + bool isSuccess() const override; + bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override; + virtual c10::intrusive_ptr getFuture() override; + +private: + c10::intrusive_ptr future_; +}; + +} // namespace c10d diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py b/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py new file mode 100644 index 0000000000..172d5ad0e6 --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py @@ -0,0 +1,25 @@ +import os + +import torch + +from setuptools import setup +from torch.utils import cpp_extension + +sources = ["src/dummy.cpp"] +include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"] + +if torch.cuda.is_available(): + module = cpp_extension.CUDAExtension( + name="dummy_collectives", sources=sources, include_dirs=include_dirs + ) +else: + module = cpp_extension.CppExtension( + name="dummy_collectives", sources=sources, include_dirs=include_dirs + ) + +setup( + name="Dummy-Collectives", + version="0.0.1", + ext_modules=[module], + cmdclass={'build_ext': cpp_extension.BuildExtension}, +) diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp b/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp new file mode 100644 index 0000000000..231ef1b1e7 --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp @@ -0,0 +1,285 @@ +#include "dummy.hpp" +#include +// #include +// #include +// #include +// #include + +namespace c10d { + + +bool WorkDummy::isCompleted() { + return true; +} + +bool WorkDummy::isSuccess() const { + return true; +} + +bool WorkDummy::wait(std::chrono::milliseconds /* unused */) { + return true; +} + +c10::intrusive_ptr WorkDummy::getFuture() { + return future_; +} + +// If necessary, pass store/rank/size to the ctor and exchange connection +// information here +BackendDummy::BackendDummy(int rank, int size) + : Backend(rank, size) {} + +const 
std::string BackendDummy::getBackendName() const{ + return "dummy"; +} + +void BackendDummy::startCoalescing(){ + return; + } + +c10::intrusive_ptr BackendDummy::endCoalescing(){ + at::Tensor outputTensors; + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce_scatter_tensor_coalesced( + std::vector& outputTensors, + std::vector& inputTensors, + const ReduceScatterOptions&) { + // printf("dummy reduce_scatter_tensor_coalesced\n"); + for (auto& outputTensor : outputTensors) { + outputTensor.fill_(1); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::allgather_into_tensor_coalesced( + std::vector& outputTensors/* outputs */, + std::vector& inputTensors/* inputs */, + const AllgatherOptions& ) { + // printf("dummy reduce_scatter_tensor_coalesced\n"); + for (auto& outputTensor : outputTensors) { + outputTensor.fill_(1); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::_reduce_scatter_base( + at::Tensor& outputTensors/* outputBuffer */, + at::Tensor& inputTensors/* inputBuffer */, + const ReduceScatterOptions& ) { + // printf("dummy _reduce_scatter_base\n"); + // for (auto& outputTensor : outputTensors) { + // outputTensor.fill_(1); + // } + outputTensors.fill_(1); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +// This is a dummy allgather that sets all output tensors to zero +// Modify the implementation to conduct real communication asynchronously +c10::intrusive_ptr BackendDummy::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& /* unused */) { + // printf("dummy allgather\n"); + for (auto& outputTensorVec : outputTensors) { + for (auto& outputTensor : outputTensorVec) { + outputTensor.fill_(1); + } + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::all_gather_object( + std::vector& outputTensors, + AnyType& inputTensors, + const AllgatherOptions& /* unused */) { + // printf("dummy all_gather_object Begin\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::_allgather_base( + at::Tensor& /* unused */, + at::Tensor& /* unused */, + const AllgatherOptions& /* unused */) { + // printf("dummy _allgather_base\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +// This is a 
dummy allreduce that sets all output tensors to zero +// Modify the implementation to conduct real communication asynchronously +c10::intrusive_ptr BackendDummy::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + // printf("dummy allreduce\n"); + for (auto& tensor : tensors) { + tensor.zero_(); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + future->markCompleted(c10::IValue(tensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::allreduce_coalesced( + std::vector& /* unused */, + const AllreduceCoalescedOptions& /* unused */) { + // printf("dummy allreduce_coalesced\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::alltoall( + std::vector& /* unused */, + std::vector& /* unused */, + const AllToAllOptions& /* unused */) { + // printf("dummy alltoall\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + // printf("dummy alltoall_base\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::barrier( + const BarrierOptions& /* unused */) { + // printf("dummy barrier\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + // printf("dummy broadcast\n"); + for (auto& tensor : tensors) { + tensor.zero_(); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + future->markCompleted(c10::IValue(tensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::gather( + std::vector>& /* unused */, + std::vector& /* unused */, + const GatherOptions& /* unused */) { + // printf("dummy gather\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce( + std::vector& /* unused */, + const ReduceOptions& /* unused */) { + // printf("dummy reduce\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce_scatter( + std::vector& /* unused */, + std::vector>& /* unused */, + const ReduceScatterOptions& /* unused */) { + // printf("dummy reduce_scatter\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::scatter( + std::vector& /* unused */, + std::vector>& /* unused */, + const 
ScatterOptions& /* unused */) { + // printf("dummy scatter\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::send( + std::vector& tensors, + int dstRank, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::recv( + std::vector& tensors, + int srcRank, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::recvAnysource( + std::vector& tensors, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::createBackendDummy( + const c10::intrusive_ptr<::c10d::Store>& /* unused */, + int rank, + int size, + const std::chrono::duration& /* unused */) { + return c10::make_intrusive(rank, size); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("createBackendDummy", &BackendDummy::createBackendDummy); +} + +} // namespace c10d diff --git a/flagscale/train/datasets/sft_dataset.py b/flagscale/train/datasets/sft_dataset.py new file mode 100644 index 0000000000..40f6404092 --- /dev/null +++ b/flagscale/train/datasets/sft_dataset.py @@ -0,0 +1,230 @@ +# Copyright (c) 2024, BAAI. All rights reserved. + +import logging +import os +import time + +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import numpy +import torch + +from megatron.core.datasets.gpt_dataset import ( + GPTDataset, + GPTDatasetConfig, + _get_ltor_masks_and_position_ids, +) +from megatron.core.datasets.indexed_dataset import IndexedDataset, get_bin_path, get_idx_path +from megatron.core.datasets.utils import Split + +logger = logging.getLogger(__name__) + + +@dataclass +class SFTDatasetConfig(GPTDatasetConfig): + """Configuration object for Megatron Core SFT datasets""" + + apply_sft_dataset_separated_loss_mask_if_existed: bool = None + """Option to apply separated loss mask files""" + + +class SFTDataset(GPTDataset): + """The base GPT dataset + + Args: + indexed_dataset (IndexedDataset): The IndexedDataset around which to build the GPTDataset + + dataset_path (Optional[str]): The real path on disk to the dataset, for bookkeeping + + indexed_indices (numpy.ndarray): The set of the documents indices to expose + + num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch. 
+ + index_split (Split): The indexed_indices Split + + config (GPTDatasetConfig): The config + """ + + def __init__( + self, + indexed_dataset: IndexedDataset, + dataset_path: Optional[str], + indexed_indices: numpy.ndarray, + num_samples: Optional[int], + index_split: Split, + config: SFTDatasetConfig, + ) -> None: + self.config = config + self.apply_sft_dataset_separated_loss_mask_if_existed = ( + config.apply_sft_dataset_separated_loss_mask_if_existed + ) + self.loss_mask_dataset = None + + super().__init__( + indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config + ) + + self._build_loss_mask_dataset() + + def _build_loss_mask_dataset(self) -> None: + """ + Load Loss Mask IndexedDataset + """ + path_prefix = None + base_prefix = "_text_document" + loss_mask_prefix = "_loss_mask_document" + if self.dataset_path.endswith(base_prefix): + path_prefix = self.dataset_path[: -len(base_prefix)] + loss_mask_prefix + if self.apply_sft_dataset_separated_loss_mask_if_existed and path_prefix: + idx_path = get_idx_path(path_prefix) + bin_path = get_bin_path(path_prefix) + if os.path.exists(idx_path) and os.path.exists(bin_path): + self.loss_mask_dataset = IndexedDataset( + path_prefix, multimodal=False, mmap=self.config.mmap_bin_files + ) + + print(f"> Used Dataset: aux_loss_mask ...") + if self.loss_mask_dataset is not None: + assert len(self.dataset) == len( + self.loss_mask_dataset + ), f"Samples are not equal, ({len(self.dataset)} != {len(self.loss_mask_dataset)})" + + def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: + """Abstract method implementation + + Args: + idx (Optioal[int]): The index into the dataset + + Returns: + Dict[str, torch.Tensor]: The sample information wrapped in a dictionary + """ + if idx is None: + # Batch padding sequence so the index does not matter + text, _ = self._query_document_sample_shuffle_indices(0) + else: + text, _ = self._query_document_sample_shuffle_indices(idx) + + text = torch.from_numpy(text).long() + if self.config.add_extra_token_to_sequence: + tokens = text[:-1].contiguous() + labels = text[1:].contiguous() + else: + tokens = text + labels = torch.roll(text, shifts=-1, dims=0) + labels[-1] = self._pad_token_id + + if ( + not self.masks_and_position_ids_are_cacheable + or not self.masks_and_position_ids_are_cached + ): + attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids( + tokens, + self.config.tokenizer.eod, + self.config.reset_position_ids, + self.config.reset_attention_mask, + self.config.eod_mask_loss, + self.config.create_attention_mask, + ) + if self.masks_and_position_ids_are_cacheable: + self.cached_attention_mask = attention_mask + self.cached_loss_mask = loss_mask + self.cached_position_ids = position_ids + self.masks_and_position_ids_are_cached = True + else: + attention_mask = self.cached_attention_mask + loss_mask = self.cached_loss_mask + position_ids = self.cached_position_ids + + # For padded sequences, mask the loss + loss_mask[labels == self._pad_token_id] = 0.0 + + # For padded sequences, ensure the embedding layer can map the token ID + tokens[tokens == self._pad_token_id] = 0 + labels[labels == self._pad_token_id] = 0 + + # Batch padding sequence so we mask the loss + if idx is None: + loss_mask = torch.zeros_like(loss_mask) + + # aux dataset + aux_loss_mask, _ = self._query_document_sample_shuffle_indices_aux_dataset( + self.loss_mask_dataset, idx + ) + if aux_loss_mask is not None: + if idx % 100 == 0: + print(f"> Used aux_loss_mask at current sample={idx} 
...") + loss_mask = torch.from_numpy(aux_loss_mask).float()[1:].contiguous() + + if self.config.create_attention_mask: + return { + "tokens": tokens, + "labels": labels, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "position_ids": position_ids, + } + else: + return { + "tokens": tokens, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + } + + def _query_document_sample_shuffle_indices_aux_dataset( + self, aux_dataset, idx: int + ) -> Tuple[numpy.ndarray, numpy.ndarray]: + """Get the aux ids and document ids for a given index + + Args: + aux_dataset (int): The aux dataset + idx (int): The index into the dataset + + Returns: + Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids + """ + if aux_dataset is None: + return (None, None) + + # Do the shuffle mapping + idx = self.shuffle_index[idx] + + # Get the beginning and end documents and offsets + doc_index_beg, doc_index_beg_offset = self.sample_index[idx] + doc_index_end, doc_index_end_offset = self.sample_index[idx + 1] + + document_ids = [] + sample_parts = [] + + # Sample spans a single document + if doc_index_beg == doc_index_end: + # Add the document id + document_ids.append(self.document_index[doc_index_beg]) + + # Add the entire sample + sample_parts.append( + aux_dataset.get( + self.document_index[doc_index_beg], + offset=doc_index_beg_offset, + length=doc_index_end_offset - doc_index_beg_offset + 1, + ) + ) + + # Sample spans multiple documents + else: + for i in range(doc_index_beg, doc_index_end + 1): + # Add the document id + document_ids.append(self.document_index[i]) + + # Add the sample part + offset = 0 if i > doc_index_beg else doc_index_beg_offset + length = None if i < doc_index_end else doc_index_end_offset + 1 + sample_parts.append( + aux_dataset.get(self.document_index[i], offset=offset, length=length) + ) + + return ( + numpy.array(numpy.concatenate(sample_parts), dtype=numpy.int64), + numpy.array(document_ids, dtype=numpy.int64), + ) diff --git a/flagscale/train/megatron/train_gpt.py b/flagscale/train/megatron/train_gpt.py index ea89a6b21e..ef897914a6 100644 --- a/flagscale/train/megatron/train_gpt.py +++ b/flagscale/train/megatron/train_gpt.py @@ -88,6 +88,8 @@ def loss_func( # Check individual rank losses are not NaN prior to DP all-reduce. 
rerun_state_machine = get_rerun_state_machine() + if args.enable_simulator: + args.check_for_nan_in_loss_and_grad = False if args.check_for_nan_in_loss_and_grad: rerun_state_machine.validate_result( result=loss, diff --git a/flagscale/train/megatron/training/arguments.py b/flagscale/train/megatron/training/arguments.py index 9d4ffc2131..096fd5466c 100644 --- a/flagscale/train/megatron/training/arguments.py +++ b/flagscale/train/megatron/training/arguments.py @@ -2659,8 +2659,10 @@ def _add_distributed_args(parser): default=False, help='if set, overlap pipeline parallel communication in warmup and flush', dest='overlap_p2p_comm_warmup_flush') group.add_argument('--distributed-backend', default='nccl', - choices=['nccl', 'gloo', 'flagcx'], + choices=['nccl', 'gloo', 'flagcx', 'dummy'], help='Which backend to use for distributed training.') + group.add_argument('--enable-simulator', action='store_true', + help='Use single process to simulate the distributed training.') group.add_argument('--distributed-timeout-minutes', type=int, default=10, help='Default timeout minutes for torch.distributed.') group.add_argument('--distributed-timeout-seconds-after-init', type=int, default=None, diff --git a/flagscale/train/megatron/training/arguments_fs.py b/flagscale/train/megatron/training/arguments_fs.py index cf8ce9e445..986260fa2e 100644 --- a/flagscale/train/megatron/training/arguments_fs.py +++ b/flagscale/train/megatron/training/arguments_fs.py @@ -13,6 +13,14 @@ warnings.warn( "flagcx is not installed, you can't use flagcx backend for communication.", ImportWarning ) + +import datetime +import multiprocessing +import os +import threading + +import dummy_collectives + from megatron.plugin.hetero.parallel_context import RankMapper class FSTrainArguments: @@ -52,23 +60,56 @@ def _initialize_distributed(self): device_id = torch.device(f"cuda:{args.local_rank}") else: device_id = None - - # Call the init process - init_process_group_kwargs = { - "backend": args.distributed_backend, - "world_size": args.world_size, - "rank": args.rank, - "timeout": timedelta(minutes=args.distributed_timeout_minutes), - } - if args.distributed_backend == "flagcx": - init_process_group_kwargs["backend"] = "cpu:gloo,cuda:flagcx" - # for communication based cpu - if args.enable_hetero and args.hetero_use_cpu_communication: - # if not all(device_type == args.hetero_device_types[0] for device_type in args.hetero_device_types): - # init_process_group_kwargs['backend'] = 'cpu:gloo' - # Force the group of backend gloo only support cpu - init_process_group_kwargs["backend"] = "cpu:gloo" - torch.distributed.init_process_group(**init_process_group_kwargs) + if args.enable_simulator: + # Define a function to initialize and run operations with a virtual rank + def run_virtual_rank(rank, world_size, timeout): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "37832" + init_process_group_kwargs = { + 'backend': args.distributed_backend, + 'world_size': world_size, + 'rank': rank, + 'timeout': datetime.timedelta(minutes=timeout), + } + torch.distributed.init_process_group(**init_process_group_kwargs) + torch.distributed.barrier() + + # Call the init process with multithreads + args.distributed_timeout_minutes = 1 + threads = [] + # Start a thread for each virtual rank + # for rank in range(1, 2): # 2 for skipping launching thousands of threads + for rank in range(1, args.world_size): + thread = threading.Thread( + target=run_virtual_rank, + args=(rank, args.world_size, args.distributed_timeout_minutes), + ) + 
+ thread.start()
+ threads.append(thread)
+ rank = 0
+ gpu_task = multiprocessing.Process(
+ target=run_virtual_rank,
+ args=(rank, args.world_size, args.distributed_timeout_minutes),
+ )
+ gpu_task.start()
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+ else:
+ # Call the init process
+ init_process_group_kwargs = {
+ "backend": args.distributed_backend,
+ "world_size": args.world_size,
+ "rank": args.rank,
+ "timeout": timedelta(minutes=args.distributed_timeout_minutes),
+ }
+ # Keep the flagcx composite backend mapping from the previous implementation
+ if args.distributed_backend == "flagcx":
+ init_process_group_kwargs["backend"] = "cpu:gloo,cuda:flagcx"
+ # for communication based cpu
+ if args.enable_hetero and args.hetero_use_cpu_communication:
+ # if not all(device_type == args.hetero_device_types[0] for device_type in args.hetero_device_types):
+ # init_process_group_kwargs['backend'] = 'cpu:gloo'
+ # Force the group of backend gloo only support cpu
+ init_process_group_kwargs['backend'] = 'cpu:gloo'
+ torch.distributed.init_process_group(**init_process_group_kwargs)
def _build_rank_mapper(self):
self._initialize_distributed()
diff --git a/flagscale/train/megatron/training/training.py b/flagscale/train/megatron/training/training.py
index e17a8da153..e3b50a5565 100644
--- a/flagscale/train/megatron/training/training.py
+++ b/flagscale/train/megatron/training/training.py
@@ -1427,7 +1427,14 @@ def setup_model_and_optimizer(
config = None
para_ctx = get_parallel_context()
if para_ctx is not None:
- config, config_overrides = para_ctx.get_optimizer_config()
+ # get_optimizer_config() may return either a config or a (config, overrides) tuple
+ optimizer_cfg = para_ctx.get_optimizer_config()
+
+ if isinstance(optimizer_cfg, tuple):
+ config, config_overrides = optimizer_cfg
+ else:
+ config = optimizer_cfg
+ config_overrides = None
if config is None:
config, config_overrides = get_megatron_optimizer_config(args)
@@ -1595,6 +1602,9 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch
optim_instance._copy_main_params_to_param_buffer()
# Forward pass.
+ # =================== Forward + Backward timing ===================
+ torch.cuda.synchronize()
+ t_fwd_start = time.time()
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
@@ -1606,6 +1616,12 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch
forward_only=False,
adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
)
+ torch.cuda.synchronize()
+ t_fwd_end = time.time()
+ fwd_time = t_fwd_end - t_fwd_start
+ # NOTE: the timed interval covers the whole forward_backward_func call; the backward time below is a rough 2x estimate
+ bwd_time = fwd_time * 2.0
+ print(f"[simulator output] forward: {fwd_time:.2f}, backward: {bwd_time:.2f}", flush=True)
+ # ================================================================
should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
if should_exit:
return {}, True, should_checkpoint, should_exit, exit_code, None, None
@@ -2663,7 +2679,7 @@ def get_e2e_base_metrics():
model, optimizer, iteration, ref_state_dict,
)
train_data_iterator = buffered_rollouts
-
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
ft_integration.on_training_step_start()
(
loss_dict,
@@ -3280,6 +3296,9 @@ def build_train_valid_test_data_loaders(build_train_valid_test_datasets_provider
args.do_test = getattr(args, "do_test", False) or flags[2].item()
if getattr(args, 'perform_rl_step', False):
args.to_test = False
+
+ if args.enable_simulator:
+ args.do_train = 1
return train_dataloader, valid_dataloaders, test_dataloader

From ac3db5e8dec6f80d8e63e9fb95405870fb1ba6a7 Mon Sep 17 00:00:00 2001
From: shuailong616 <452509829@qq.com>
Date: Thu, 8 Jan 2026 18:24:33 +0800
Subject: [PATCH 2/3] code format
---
.../simulator/analylize_pipeline_time.py | 72 +++++++++----------
.../runner/auto_tuner/simulator/config_gen.py | 10 +--
2 files changed, 39 insertions(+), 43 deletions(-)
diff --git a/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py
index efe28be9cc..f975b39800 100644
--- a/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py
+++ b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py
@@ -2,7 +2,9 @@
import re
import subprocess
import time
+
from collections import defaultdict
+
# from megatron.training import get_args
@@ -120,18 +122,21 @@ def compute_pipeline_parallelism_cost(
LAYER_RE = re.compile(r"decoder\.layers\.(\d+)\.(.+)")
+
def extract_stage_ops_from_raw_log(log_text: str):
- layers = defaultdict(lambda: {
- "has_attention": False,
- "has_mlp": False,
- "has_qkv": False,
- "has_proj": False,
- "has_fc1": False,
- "has_fc2": False,
- })
+ layers = defaultdict(
+ lambda: {
+ "has_attention": False,
+ "has_mlp": False,
+ "has_qkv": False,
+ "has_proj": False,
+ "has_fc1": False,
+ "has_fc2": False,
+ }
+ )
for raw_line in log_text.splitlines():
- line = raw_line.strip()
+ line = raw_line.strip()
if "decoder.layers."
not in line: continue @@ -163,6 +168,7 @@ def extract_stage_ops_from_raw_log(log_text: str): return layers + def tp_collectives_per_stage(layers, sequence_parallel=False): total = 0 per_layer = {} @@ -178,11 +184,8 @@ def tp_collectives_per_stage(layers, sequence_parallel=False): return total, per_layer -def tp_collectives_per_layer( - has_attention=True, - has_mlp=True, - sequence_parallel=False -): + +def tp_collectives_per_layer(has_attention=True, has_mlp=True, sequence_parallel=False): cnt = 0 if has_attention: cnt += 1 # qkv backward @@ -194,26 +197,19 @@ def tp_collectives_per_layer( cnt += 4 # ln fwd/bwd rs + ag return cnt + def ring_allreduce_time( - n_bytes, - N_ranks, - N_nodes, - alpha_base, - alpha_intra, - alpha_inter, - hops, - alpha_switch, - beta, + n_bytes, N_ranks, N_nodes, alpha_base, alpha_intra, alpha_inter, hops, alpha_switch, beta ): - alpha_hw = ( - 2 * (N_ranks - N_nodes) * alpha_intra - + 2 * (N_nodes - 1) * (alpha_inter * hops * alpha_switch) + alpha_hw = 2 * (N_ranks - N_nodes) * alpha_intra + 2 * (N_nodes - 1) * ( + alpha_inter * hops * alpha_switch ) bw_term = 2 * (N_ranks - 1) / N_ranks * n_bytes * beta return alpha_base + alpha_hw + bw_term + def stage_has_tp_from_process_mesh(process_mesh): assert len(process_mesh) % 5 == 0 @@ -221,7 +217,7 @@ def stage_has_tp_from_process_mesh(process_mesh): stage_id = 0 for i in range(0, len(process_mesh), 5): - device = process_mesh[i:i+5] + device = process_mesh[i : i + 5] tp = device[0] pp = device[4] @@ -243,10 +239,10 @@ def simulator( ): # os.environ["PYTHONPATH"] = "/share/project/heyongzhe/FlagScale/megatron:/share/project/heyongzhe/FlagScale" - #os.environ["PYTHONPATH"] = ( + # os.environ["PYTHONPATH"] = ( # "/workspace/single_process_simulator_nd/FlagScale:" # "/workspace/single_process_simulator_nd/FlagScale/third_party/Megatron-LM" - #) + # ) os.environ["ENABLE_SIMULATOR"] = "1" os.environ["CUDA_VISIBLE_DEVICES"] = "3" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -311,31 +307,31 @@ def simulator( s_out = extract_stage_ops_from_raw_log(output) reduce_op_cnt = tp_collectives_per_stage(s_out)[0] - n_bytes = 16.7 - n_rank = 2 + n_rank = 2 n_nodes = 1 alpha_base = 5e-6 alpha_intra = 1e-6 alpha_inter = 3e-6 hops = 1 if n_nodes > 1 else 0 alpha_switch = 1.2 - beta = 1 / (25 * 1024**3) - - time = ring_allreduce_time(n_bytes, n_rank, n_nodes, alpha_base, alpha_intra, alpha_inter, hops, alpha_switch, beta) - + beta = 1 / (25 * 1024**3) + + time = ring_allreduce_time( + n_bytes, n_rank, n_nodes, alpha_base, alpha_intra, alpha_inter, hops, alpha_switch, beta + ) + fw_cm_time = time * reduce_op_cnt * 0.33 bw_cm_time = time * reduce_op_cnt * 0.66 stp = stage_has_tp_from_process_mesh(process_mesh) - if match: fwd_time = float(match.group(1)) bwd_time = float(match.group(2)) # comm_time = float(match.group(3)) comm_time = estimate_comm_time_between_stages(1, 2048, 4096) - if stp[stage]: - fwd_time += fw_cm_time + if stp[stage]: + fwd_time += fw_cm_time bwd_time += bw_cm_time print("forward:", fwd_time) print("backward:", bwd_time) diff --git a/flagscale/runner/auto_tuner/simulator/config_gen.py b/flagscale/runner/auto_tuner/simulator/config_gen.py index c5b81c2047..620656e1c7 100644 --- a/flagscale/runner/auto_tuner/simulator/config_gen.py +++ b/flagscale/runner/auto_tuner/simulator/config_gen.py @@ -8,8 +8,8 @@ import analylize_pipeline_time # from itertools import product -#import flagscale.train.theoretical_memory_usage as mem_usg -import flagscale.train.megatron.training.theoretical_memory_usage as mem_usg +# 
import flagscale.train.theoretical_memory_usage as mem_usg +import flagscale.train.megatron.training.theoretical_memory_usage as mem_usg BYTES_OF_GB = 10**9 @@ -399,7 +399,7 @@ def append_config_to_file(file_path: str, config: dict): # hetero_configs = [] # generate all possible and legal mesh configs, each element of hetero_configs is a mesh list - + gen_hetero_configs( device_type_list=device_type_list, device_num_list=device_num_list, @@ -409,7 +409,7 @@ def append_config_to_file(file_path: str, config: dict): # num_micro_batches=num_micro_batches, # hetero_configs=hetero_configs ) - + # assert False # simulation file_path = "hetero_configs.json" @@ -419,7 +419,7 @@ def append_config_to_file(file_path: str, config: dict): print(hetero_config) pp_cost = hetero_config['simulated_time'] = analylize_pipeline_time.analyze_pp_time( # pp_cost = hetero_config.simulated_time = analylize_pipeline_time.analyze_pp_time( - #scheme="vpp", + # scheme="vpp", scheme="1F1B", num_micro_batches=num_micro_batches, process_mesh=hetero_config['mesh'], From 6c130dfd5c80af6ed620dc48096fe5c121853275 Mon Sep 17 00:00:00 2001 From: shuailong <80105174+shuailong616@users.noreply.github.com> Date: Fri, 9 Jan 2026 09:14:05 +0800 Subject: [PATCH 3/3] Modify README for PYTHONPATH setup Updated PYTHONPATH instructions and removed unnecessary lines. --- flagscale/runner/auto_tuner/simulator/README.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/flagscale/runner/auto_tuner/simulator/README.md b/flagscale/runner/auto_tuner/simulator/README.md index 4ac43d38ce..930beca8e4 100644 --- a/flagscale/runner/auto_tuner/simulator/README.md +++ b/flagscale/runner/auto_tuner/simulator/README.md @@ -20,13 +20,8 @@ num_layers = 4 a. set PYTHONPATH ``` export PYTHONPATH=/***/FlagScale:$PYTHONPATH - export PYTHONPATH=$PYTHONPATH:/***/FlagScale/third_party/Megatron-LM - - vim /***/FlagScale/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py - os.environ["PYTHONPATH"] = ( - "/***/FlagScale:" - "/***/FlagScale/third_party/Megatron-LM" - ) + export PYTHONPATH=/***/FlagScale/flagscale/train:$PYTHONPATH + ``` b. run