2 changes: 1 addition & 1 deletion xllm/core/common/global_flags.cpp
@@ -274,7 +274,7 @@ DEFINE_string(kv_cache_transfer_mode,
DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTransfer listen port.");

DEFINE_bool(enable_shm,
true,
false,
"Whether to enable shared memory for executing model.");
// --- function call config ---

4 changes: 4 additions & 0 deletions xllm/core/common/options.h
@@ -176,6 +176,10 @@ class Options {
// for offline inference: the path to spawn worker binary
PROPERTY(std::string, spawn_worker_path) = "";

// use shared memory for inter-process communication in the single-machine
// multi-GPU scenario.
PROPERTY(bool, enable_shm) = false;

// whether the worker and master are on the same machine.
PROPERTY(bool, is_local) = false;
};
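The new `enable_shm` property is read and written through the same fluent accessor style the other hunks in this PR rely on (`.enable_shm(value)` to set, `options.enable_shm()` to read). The snippet below is a minimal self-contained sketch of that pattern, assuming `PROPERTY(bool, enable_shm)` expands to a chainable setter plus a const getter; it is illustrative only and does not include the real xLLM headers.

```cpp
#include <iostream>

// Illustrative stand-in for what PROPERTY(bool, enable_shm) is assumed to
// generate: a chainable setter and a const getter, matching how the
// master.cpp and dist_manager.cpp hunks in this PR use the option.
class Options {
 public:
  Options& enable_shm(bool v) {  // chainable setter
    enable_shm_ = v;
    return *this;
  }
  bool enable_shm() const { return enable_shm_; }  // getter

 private:
  bool enable_shm_ = false;  // disabled by default, matching the diff
};

int main() {
  Options opts;
  opts.enable_shm(true);  // plumbed from FLAGS_enable_shm or the Python API
  if (opts.enable_shm()) {
    std::cout << "shared-memory channels may be used for local workers\n";
  }
  return 0;
}
```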
10 changes: 6 additions & 4 deletions xllm/core/distributed_runtime/dist_manager.cpp
@@ -96,19 +96,20 @@ void DistManager::setup_single_node_workers(const runtime::Options& options) {
namespace {
std::unique_ptr<CommChannel> create_channel(const std::string& worker_addrs,
int r,
int dp_local_tp_size) {
int dp_local_tp_size,
const runtime::Options& options) {
std::unique_ptr<CommChannel> channel;

if (net::extract_ip(FLAGS_master_node_addr) ==
net::extract_ip(worker_addrs) &&
FLAGS_enable_shm) {
options.enable_shm()) {
// create shared memory manager for local rank
bool is_driver = false;
int dp_group = r / dp_local_tp_size;
if (r % dp_local_tp_size == 0) {
is_driver = true;
}
channel = std::make_unique<ShmChannel>(dp_group, r, is_driver);
channel = std::make_unique<ShmChannel>(dp_group, r, is_driver, options);
} else {
channel = std::make_unique<CommChannel>();
}
@@ -220,7 +221,8 @@ void DistManager::setup_multi_node_workers(
<< r;
return;
}
auto channel = create_channel(worker_addrs_map[r], r, dp_local_tp_size);
auto channel =
create_channel(worker_addrs_map[r], r, dp_local_tp_size, options);
worker_clients_.emplace_back(
std::make_unique<RemoteWorker>(r,
worker_addrs_map[r],
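In the `create_channel` change above, a `ShmChannel` is only built when the worker shares the master node's IP and `enable_shm` is set; within each data-parallel group, the rank whose index is a multiple of `dp_local_tp_size` becomes the driver. The standalone sketch below (example values only, not xLLM code) shows how ranks map to `dp_group` and `is_driver` under that rule.

```cpp
#include <iostream>

// Illustration of the dp_group / is_driver computation from create_channel.
// The sizes are assumptions for the example: world_size = 8 ranks split into
// dp_size = 2 groups, so dp_local_tp_size = 4.
int main() {
  const int world_size = 8;
  const int dp_local_tp_size = 4;
  for (int r = 0; r < world_size; ++r) {
    int dp_group = r / dp_local_tp_size;           // data-parallel group of this rank
    bool is_driver = (r % dp_local_tp_size == 0);  // first rank of each group drives the shm channel
    std::cout << "rank " << r << " -> dp_group " << dp_group
              << (is_driver ? " (driver)" : "") << "\n";
  }
  return 0;
}
```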
12 changes: 8 additions & 4 deletions xllm/core/distributed_runtime/shm_channel.cpp
@@ -19,7 +19,11 @@ limitations under the License.

namespace xllm {

ShmChannel::ShmChannel(int dp_group, int rank, bool is_driver) {
ShmChannel::ShmChannel(int dp_group,
int rank,
bool is_driver,
const runtime::Options& options)
: enable_shm_(options.enable_shm()) {
bool is_creator;

if (is_driver) {
@@ -45,7 +49,7 @@ bool ShmChannel::execute_model_with_shm(
int use_shm_ret = input_shm_manager_->raw_input_write(inputs);
if (use_shm_ret < 0) {
// fallback
FLAGS_enable_shm = false;
enable_shm_ = false;
LOG(ERROR)
<< "RemoteWorker SharedMemoryManager write failed, fallback to brpc.";
return false;
@@ -58,7 +62,7 @@ void ShmChannel::execute_model_async(
void ShmChannel::execute_model_async(
const std::vector<RawForwardInput>& inputs,
folly::Promise<std::optional<RawForwardOutput>>& promise) {
if (FLAGS_enable_shm) {
if (enable_shm_) {
// write to shared memory, then wait output.
RawForwardOutput raw_output;
bool shm_success = execute_model_with_shm(inputs, raw_output);
@@ -69,4 +73,4 @@ void ShmChannel::execute_model_async(
}
execute_model_with_brpc(inputs, promise);
}
} // namespace xllm
} // namespace xllm
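The main behavioral change in this file is that a failed shared-memory write now clears the per-channel `enable_shm_` member instead of the process-wide `FLAGS_enable_shm`, so one channel falling back to brpc no longer disables shared memory for every other channel. The sketch below illustrates that fallback pattern with stand-in types; it is not the real xLLM implementation.

```cpp
#include <iostream>
#include <vector>

// Stand-in types illustrating the per-channel fallback adopted above: a failed
// shared-memory write disables shm for this channel only and reroutes the
// request to the RPC path.
class Channel {
 public:
  explicit Channel(bool enable_shm) : enable_shm_(enable_shm) {}

  void execute(const std::vector<int>& inputs) {
    if (enable_shm_) {
      if (try_shm_write(inputs)) {
        return;             // fast path: shared memory succeeded
      }
      enable_shm_ = false;  // fall back for this channel only
    }
    execute_with_rpc(inputs);  // slow path
  }

 private:
  bool try_shm_write(const std::vector<int>&) { return false; }  // pretend shm failed
  void execute_with_rpc(const std::vector<int>&) { std::cout << "rpc path\n"; }

  bool enable_shm_;
};

int main() {
  Channel a(/*enable_shm=*/true), b(/*enable_shm=*/true);
  a.execute({1});  // a's shm write fails, so a falls back to RPC
  b.execute({2});  // b is unaffected and still tries shared memory first
  return 0;
}
```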
10 changes: 8 additions & 2 deletions xllm/core/distributed_runtime/shm_channel.h
@@ -16,12 +16,16 @@ limitations under the License.
#pragma once
#include "comm_channel.h"
#include "runtime/forward_shared_memory_manager.h"
#include "runtime/options.h"

namespace xllm {

class ShmChannel : public CommChannel {
public:
explicit ShmChannel(int dp_group, int rank, bool is_driver);
explicit ShmChannel(int dp_group,
int rank,
bool is_driver,
const runtime::Options& options);
~ShmChannel() = default;

void execute_model_async(
@@ -31,8 +35,10 @@ class ShmChannel : public CommChannel {
private:
bool execute_model_with_shm(const std::vector<RawForwardInput>& inputs,
RawForwardOutput& raw_output);

bool enable_shm_ = false;
std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager_ = nullptr;
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager_ = nullptr;
};

} // namespace xllm
} // namespace xllm
@@ -38,14 +38,16 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
int world_size,
int device_idx,
int num_decoding_tokens,
int block_size) {
int block_size,
bool enable_shm) {
// TODO: pass whole xllm::runtime::Options here from main process.
xllm::runtime::Options runner_options;
runner_options.block_size(block_size)
.num_decoding_tokens(num_decoding_tokens)
.enable_schedule_overlap(false)
.enable_offline_inference(true)
.master_node_addr(master_node_addr);
.master_node_addr(master_node_addr)
.enable_shm(enable_shm);
FLAGS_enable_schedule_overlap = false;
FLAGS_master_node_addr = master_node_addr;
FLAGS_block_size = block_size;
@@ -27,7 +27,8 @@ class SpawnWorkerServer final {
int world_size,
int device_idx,
int num_decoding_tokens,
int block_size);
int block_size,
bool enable_shm);

~SpawnWorkerServer() = default;

@@ -28,10 +28,11 @@ limitations under the License.
// @device_idx
// @num_decoding_tokens
// @block_size
// @enable_shm
int main(int argc, char* argv[]) {
if (argc < 7) {
if (argc < 8) {
LOG(ERROR)
<< "Spwan worker process receive wrong args. Need 7 args, receive "
<< "Spwan worker process receive wrong args. Need 8 args, receive "
<< argc;
return 1;
}
@@ -50,22 +51,25 @@ int main(int argc, char* argv[]) {
int device_idx = atoi(argv[5]);
int num_decoding_tokens = atoi(argv[6]);
int block_size = atoi(argv[7]);
int enable_shm = atoi(argv[8]);

LOG(INFO) << "Spwan worker: "
<< "master_node_addr = " << master_node_addr
<< ", local_rank = " << local_rank
<< ", world_size = " << world_size
<< ", device_idx = " << device_idx
<< ", num_decoding_tokens = " << num_decoding_tokens
<< ", block_size = " << block_size << "\n";
<< ", block_size = " << block_size
<< ", enable_shm = " << (enable_shm > 0) << "\n";

xllm::SpawnWorkerServer worker(master_node_addr,
local_rank,
global_rank,
world_size,
device_idx,
num_decoding_tokens,
block_size);
block_size,
enable_shm > 0);

worker.run();

43 changes: 23 additions & 20 deletions xllm/core/distributed_runtime/worker_server.cpp
@@ -110,7 +110,7 @@ void WorkerServer::create_server(
std::unique_ptr<Worker> worker =
std::make_unique<Worker>(*parallel_args, device, options, worker_type);
worker_service->set_worker(std::move(worker));
if (FLAGS_enable_shm && input_shm_manager && output_shm_manager) {
if (options.enable_shm() && input_shm_manager && output_shm_manager) {
worker_service->create_polling_shm_thread(std::move(input_shm_manager),
std::move(output_shm_manager));
}
@@ -127,29 +127,32 @@ void WorkerServer::create_spawn_server(int local_rank,
const ParallelArgs& parallel_args,
const torch::Device& d,
const runtime::Options& options) {
auto local_rank_str0 = std::to_string(local_rank);
const char* local_rank_str = local_rank_str0.c_str();
auto global_rank_str0 = std::to_string(parallel_args.rank());
const char* global_rank_str = global_rank_str0.c_str();
auto world_size_str0 = std::to_string(parallel_args.world_size());
const char* world_size_str = world_size_str0.c_str();
auto device_idx_str0 = std::to_string(d.index());
const char* device_idx_str = device_idx_str0.c_str();
auto num_decoding_tokens_str0 = std::to_string(options.num_decoding_tokens());
const char* num_decoding_tokens_str = num_decoding_tokens_str0.c_str();
auto block_size_str0 = std::to_string(options.block_size());
const char* block_size_str = block_size_str0.c_str();
auto local_rank_str = std::to_string(local_rank);
const char* local_rank_ptr = local_rank_str.c_str();
auto global_rank_str = std::to_string(parallel_args.rank());
const char* global_rank_ptr = global_rank_str.c_str();
auto world_size_str = std::to_string(parallel_args.world_size());
const char* world_size_ptr = world_size_str.c_str();
auto device_idx_str = std::to_string(d.index());
const char* device_idx_ptr = device_idx_str.c_str();
auto num_decoding_tokens_str = std::to_string(options.num_decoding_tokens());
const char* num_decoding_tokens_ptr = num_decoding_tokens_str.c_str();
auto block_size_str = std::to_string(options.block_size());
const char* block_size_ptr = block_size_str.c_str();
auto enable_shm_str = std::to_string(options.enable_shm());
const char* enable_shm_ptr = enable_shm_str.c_str();
std::string spawn_worker_bin_path =
options.spawn_worker_path() + "/spawn_worker";
LOG(INFO) << "Spawn worker path: " << spawn_worker_bin_path;
const char* argv[] = {spawn_worker_bin_path.c_str(),
master_node_addr.c_str(),
local_rank_str,
global_rank_str,
world_size_str,
device_idx_str,
num_decoding_tokens_str,
block_size_str,
local_rank_ptr,
global_rank_ptr,
world_size_ptr,
device_idx_ptr,
num_decoding_tokens_ptr,
block_size_ptr,
enable_shm_ptr,
nullptr};
pid_t pid;
posix_spawn_file_actions_init(&file_actions_);
@@ -173,7 +176,7 @@ void WorkerServer::prepare_shm(
const runtime::Options& options,
std::unique_ptr<ForwardSharedMemoryManager>& input_shm_manager,
std::unique_ptr<ForwardSharedMemoryManager>& output_shm_manager) {
if (options.is_local() && FLAGS_enable_shm) {
if (options.is_local() && options.enable_shm()) {
bool is_creator;
int dp_local_tp_size = parallel_args.world_size() / parallel_args.dp_size();
int dp_group = parallel_args.rank() / dp_local_tp_size;
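The `create_spawn_server` hunk above mostly renames the local temporaries for readability and appends the new `enable_shm` argument; the pattern it relies on is that each `std::to_string` result lives in a named `std::string` whose `c_str()` pointer is placed into `argv` and remains valid until the spawn call. Below is a small standalone sketch of that pattern with placeholder values and only a subset of the real argument list.

```cpp
#include <cstdio>
#include <string>

// Sketch of the argv-building pattern used in create_spawn_server: keep each
// std::to_string result in a named std::string so the c_str() pointers stored
// in argv stay valid until the child process is spawned. Values are placeholders.
int main() {
  std::string block_size_str = std::to_string(128);
  std::string enable_shm_str = std::to_string(1);  // new trailing argument in this PR
  const char* argv_demo[] = {"spawn_worker",
                             "127.0.0.1:9999",  // master_node_addr (placeholder)
                             block_size_str.c_str(),
                             enable_shm_str.c_str(),
                             nullptr};
  for (int i = 0; argv_demo[i] != nullptr; ++i) {
    std::printf("%s\n", argv_demo[i]);
  }
  return 0;
}
```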
3 changes: 3 additions & 0 deletions xllm/core/runtime/master.cpp
@@ -112,6 +112,7 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
.enable_schedule_overlap(options_.enable_schedule_overlap())
.enable_offline_inference(options_.enable_offline_inference())
.spawn_worker_path(options_.spawn_worker_path())
.enable_shm(options_.enable_shm())
.is_local(options_.is_local());

auto engine = std::make_unique<VLMEngine>(eng_options);
@@ -154,6 +155,7 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
.enable_cache_upload(options_.enable_cache_upload())
.enable_offline_inference(options_.enable_offline_inference())
.spawn_worker_path(options_.spawn_worker_path())
.enable_shm(options_.enable_shm())
.is_local(options_.is_local());

if (options_.device_ip().has_value()) {
@@ -201,6 +203,7 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
.enable_continuous_kvcache(options_.enable_continuous_kvcache())
.enable_offline_inference(options_.enable_offline_inference())
.spawn_worker_path(options_.spawn_worker_path())
.enable_shm(options_.enable_shm())
.is_local(options_.is_local());

if (options_.device_ip().has_value()) {
4 changes: 4 additions & 0 deletions xllm/core/runtime/options.h
@@ -163,6 +163,10 @@ struct Options {
// the path to spawn worker binary
PROPERTY(std::string, spawn_worker_path) = "";

// use shared memory for inter-process communication in the single-machine
// multi-GPU scenario.
PROPERTY(bool, enable_shm) = false;

// whether the worker and master are on the same machine.
PROPERTY(bool, is_local) = false;
};
1 change: 1 addition & 0 deletions xllm/pybind/args.py
@@ -39,6 +39,7 @@ def __init__(self):
self.parser.add_argument('--enable_multi_stream_parallel', action='store_true', help='Whether to enable computation communication overlap.')
self.parser.add_argument('--disable_ttft_profiling', action='store_true', help='Whether to disable TTFT profiling.')
self.parser.add_argument('--enable_forward_interruption', action='store_true', help='Whether to enable forward interruption.')
self.parser.add_argument('--enable_shm', action='store_true', help='Use shared memory for inter-process communication in the single-machine multi-GPU scenario.')

def parse_args(self):
return self.parser.parse_args()
3 changes: 2 additions & 1 deletion xllm/pybind/bind.cpp
@@ -81,7 +81,8 @@ PYBIND11_MODULE(xllm_export, m) {
&Options::enable_forward_interruption_)
.def_readwrite("enable_offline_inference",
&Options::enable_offline_inference_)
.def_readwrite("spawn_worker_path", &Options::spawn_worker_path_);
.def_readwrite("spawn_worker_path", &Options::spawn_worker_path_)
.def_readwrite("enable_shm", &Options::enable_shm_);

// 2. export LLMMaster
py::class_<LLMMaster>(m, "LLMMaster")
2 changes: 2 additions & 0 deletions xllm/pybind/llm.py
@@ -43,6 +43,7 @@ def __init__(
kv_cache_transfer_mode: str = 'PUSH',
disable_ttft_profiling: bool = False,
enable_forward_interruption: bool = False,
enable_shm: bool = False,
**kwargs,
) -> None:

@@ -93,6 +94,7 @@ def __init__(
options.enable_forward_interruption = enable_forward_interruption
options.enable_offline_inference = True
options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
options.enable_shm = enable_shm
self.master = LLMMaster(options)

def finish(self):
2 changes: 2 additions & 0 deletions xllm/pybind/vlm.py
@@ -41,6 +41,7 @@ def __init__(
enable_disagg_pd: bool = False,
enable_schedule_overlap: bool = False,
kv_cache_transfer_mode: str = 'PUSH',
enable_shm: bool = False,
**kwargs,
) -> None:

@@ -88,6 +89,7 @@ def __init__(
options.kv_cache_transfer_mode = kv_cache_transfer_mode
options.enable_offline_inference = True
options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
options.enable_shm = enable_shm
self.master = VLMMaster(options)

def finish(self):
1 change: 1 addition & 0 deletions xllm/xllm.cpp
@@ -181,6 +181,7 @@ int run() {
.max_global_tpot_ms(FLAGS_max_global_tpot_ms)
.max_requests_per_batch(FLAGS_max_requests_per_batch)
.enable_continuous_kvcache(FLAGS_enable_continuous_kvcache)
.enable_shm(FLAGS_enable_shm)
.is_local(is_local);

InstanceName::name()->set_name(options.instance_name().value_or(""));