Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,9 @@ static bool append_fanin_or_fail(
// keyed to the live producer, and doing it before the ++ still suppresses a
// double-count for a producer named twice in one submission.
prod_state->lock_fanout();
PTO2TaskState pstate = prod_state->task_state.load(std::memory_order_acquire);
bool gone = prod_state->task == nullptr || prod_state->task->task_id.local() != producer_task_id.local() ||
prod_state->task_state.load(std::memory_order_acquire) == PTO2_TASK_CONSUMED;
pstate == PTO2_TASK_CONSUMED;
bool claim = !gone && !fanin_builder->mark_seen(prod_ring, prod_slot);
if (claim) {
// Low bits hold the consumer count; bit31 is the scope ref. The consumer
Expand Down Expand Up @@ -345,6 +346,43 @@ static bool append_fanin_or_fail(
return true;
}

/**
 * Checks whether the producers recorded in `fanin_builder` have all reached a
 * task state at or beyond PTO2_TASK_COMPLETED.
 *
 * An empty builder (count == 0) yields true without visiting anything.
 * Otherwise the result is whatever `for_each` returns for the predicate below
 * (presumably "true only if the predicate held for every entry" — confirm
 * against PTO2FaninBuilder::for_each).
 *
 * @param fanin_builder  Builder holding the producers collected for the task.
 * @return true when there is nothing to wait on, otherwise the for_each result.
 */
static bool all_claimed_fanin_completed(const PTO2FaninBuilder &fanin_builder) {
    return fanin_builder.count == 0 || fanin_builder.for_each([](PTO2TaskSlotState *producer) -> bool {
        // A missing producer entry is never treated as completed.
        if (producer == nullptr) {
            return false;
        }
        // Acquire-load so the state observed here is ordered with the
        // release-store that published it.
        const auto observed = producer->task_state.load(std::memory_order_acquire);
        return observed >= PTO2_TASK_COMPLETED;
    });
}

/**
 * Routes a task that the orchestrator has decided is ready.
 *
 * Tasks whose active mask maps to PTO2ResourceShape::DUMMY are claimed with
 * try_claim_ready_once() and, on success, pushed onto the scheduler's
 * dummy_ready_queue. Every other shape is handed to route_ready_once().
 *
 * @param sched       Scheduler state that owns the ready queues.
 * @param slot_state  Slot of the task being routed.
 * @return false if a DUMMY task could not be claimed; true after a DUMMY task
 *         was queued; otherwise the result of route_ready_once().
 */
static bool route_orch_inline_ready(PTO2SchedulerState *sched, PTO2TaskSlotState &slot_state) {
    const bool is_dummy = slot_state.active_mask.to_shape() == PTO2ResourceShape::DUMMY;
    if (!is_dummy) {
        return sched->route_ready_once(slot_state);
    }

    // DUMMY tasks go to their own queue, but only if this caller wins the claim.
    if (sched->try_claim_ready_once(slot_state)) {
        sched->dummy_ready_queue.push(&slot_state);
        return true;
    }
    return false;
}

/**
 * Attempts to wire a task from the orchestrator side instead of deferring it.
 *
 * The call gives up immediately (returns false) when `wfanin` is not positive
 * or when the ring's dep-pool lock cannot be taken without waiting. With the
 * lock held it checks that the dep pool has at least `wfanin` entries
 * available; if not, it reclaims once using the ring's last_task_alive value
 * and checks again. wire_task() is invoked only when enough entries are
 * available. The lock is released on every path that acquired it.
 *
 * @param sched       Scheduler state holding the per-ring scheduling state.
 * @param slot_state  Slot of the task to wire; its ring_id selects the ring.
 * @param wfanin      Number of dep-pool entries the wiring needs.
 * @return true if wire_task() was called, false otherwise.
 */
static bool try_orch_prewire_task(PTO2SchedulerState *sched, PTO2TaskSlotState &slot_state, int32_t wfanin) {
    if (wfanin <= 0) {
        return false;
    }

    auto &ring_state = sched->ring_sched_states[slot_state.ring_id];
    if (!ring_state.try_lock_dep_pool()) {
        return false;
    }

    bool has_capacity = !(ring_state.dep_pool.available() < wfanin);
    if (!has_capacity) {
        // Not enough room: reclaim once, then re-evaluate.
        const int32_t last_alive = ring_state.ring->fc.last_task_alive.load(std::memory_order_acquire);
        ring_state.dep_pool.reclaim(*ring_state.ring, last_alive);
        has_capacity = ring_state.dep_pool.available() >= wfanin;
    }

    if (has_capacity) {
        sched->wire_task(ring_state, &slot_state, wfanin);
    }

    ring_state.unlock_dep_pool();
    return has_capacity;
}

static void scope_tasks_push(PTO2OrchestratorState *orch, PTO2TaskSlotState *task_slot_state);

struct PTO2PreparedTask {
Expand Down Expand Up @@ -437,6 +475,7 @@ static bool prepare_task(
out->slot_state->bind_ring(ring_id);
out->slot_state->reset_for_reuse();
out->slot_state->fanin_count = 0;
out->slot_state->dep_pool_mark = 0;

out->payload->prefetch(args.tensor_count(), args.scalar_count());

Expand Down Expand Up @@ -864,12 +903,33 @@ static TaskOutputTensors submit_task_common(

CYCLE_COUNT_LAP(g_orch_args_cycle);

// === STEP 6: push to wiring queue ===
// Deferred wiring: orchestrator only stores dependency metadata and increments
// fanout_count. The actual fanout_head wiring (lock + dep_pool + early_finished)
// is handled asynchronously by scheduler thread 0 via the wiring queue.
// === STEP 6: inline already-ready tasks or push to wiring queue ===
// Zero-fanin tasks and tasks whose claimed producers are already completed
// do not need fanout links or dep_pool entries, so O can make them ready
// immediately and skip S-side wiring.
bool inline_wired = false;
if (fanin_builder.count == 0) {
cur_slot_state.fanin_count = 1;
cur_slot_state.fanin_refcount.store(1, std::memory_order_release);
(void)route_orch_inline_ready(sched, cur_slot_state);
inline_wired = true;
} else if (all_claimed_fanin_completed(fanin_builder)) {
int32_t ready_seed = fanin_builder.count + 1;
cur_slot_state.fanin_count = ready_seed;
payload.dispatch_fanin.store(fanin_builder.count, std::memory_order_release);
cur_slot_state.fanin_refcount.store(ready_seed, std::memory_order_release);
(void)route_orch_inline_ready(sched, cur_slot_state);
inline_wired = true;
} else if (try_orch_prewire_task(sched, cur_slot_state, fanin_builder.count)) {
inline_wired = true;
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

// Deferred wiring: for tasks that still need live producer fanout links,
// orchestrator only stores dependency metadata and increments fanout_count.
// The actual fanout_head wiring (lock + dep_pool + early_finished) is
// handled asynchronously by scheduler thread 0 via the wiring queue.
// Push to global wiring queue — scheduler sets fanin_count, wires fanout, checks readiness
if (!sched->wiring.queue.push(&cur_slot_state)) {
if (!inline_wired && !sched->wiring.queue.push(&cur_slot_state)) {
// producer_blocked is the wiring deadlock detector's "orchestrator is
// stuck in push" observable: set ONLY while we actually spin (queue
// full), cleared on exit, so the just-filled-then-scope_end case (push
Expand Down Expand Up @@ -1071,7 +1131,7 @@ TaskOutputTensors PTO2OrchestratorState::alloc_tensors(const L0TaskArgs &args) {
// required so scope_end can release the producer-side reference and
// drive the slot to CONSUMED, but worker dispatch fields are never
// observed for hidden alloc tasks.
prepared.slot_state->task_state.store(PTO2_TASK_COMPLETED, std::memory_order_release);
prepared.slot_state->mark_completed();
}
orch->inline_completed_tasks++;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,15 @@ static_assert(
// never reach -> provable deadlock.
static constexpr uint32_t PTO2_FANOUT_SCOPE_BIT = 0x80000000u;

// Claim values kept in a one-byte field. The enum is unscoped so the
// enumerators convert implicitly to the uint8_t they are stored as.
// NOTE(review): presumably PTO2_READY_CLAIMED marks a task that has already
// been routed to a ready queue once — confirm against the claim-once helpers.
enum PTO2ReadyState : uint8_t {
    PTO2_READY_UNCLAIMED = 0u,      // no bits set
    PTO2_READY_CLAIMED = 1u << 0,   // == 1
};

// Flag bit with value 2, held in a one-byte enum. The enum is unscoped so the
// enumerator converts implicitly when used in bitwise operations on a uint8_t.
enum PTO2CompletionFlag : uint8_t {
    PTO2_COMPLETION_DONE = 1u << 1,  // == 2
};

struct alignas(64) PTO2TaskSlotState {
// Fanout lock + list (accessed together under lock in on_task_complete)
std::atomic<int32_t> fanout_lock; // Per-task spinlock (0=unlocked, 1=locked)
Expand Down Expand Up @@ -437,7 +446,7 @@ struct alignas(64) PTO2TaskSlotState {
// sequenced before on_subtask_complete's acq_rel fetch_add and the read
// after, so all earlier subtasks' writes are visible to the last subtask.
std::atomic<bool> any_subtask_deferred{false};
uint8_t _async_pad{0};
std::atomic<uint8_t> ready_state{PTO2_READY_UNCLAIMED};
int32_t dep_pool_mark{0}; // Dep pool top after wiring (thread-0-only)

std::atomic<int16_t> completed_subtasks{0}; // Each core completion increments by 1
Expand Down Expand Up @@ -467,6 +476,15 @@ struct alignas(64) PTO2TaskSlotState {
task = t;
}

// Marks this slot as completed in two steps, in this order:
//  1. task_state is set to PTO2_TASK_COMPLETED (release store).
//  2. PTO2_COMPLETION_DONE is OR-ed into ready_state (release RMW), which
//     leaves any other bits already present in ready_state untouched.
// The two updates are separate atomic operations, so another thread can
// observe the new task_state before the completion bit is set.
void mark_completed() {
task_state.store(PTO2_TASK_COMPLETED, std::memory_order_release);
ready_state.fetch_or(PTO2_COMPLETION_DONE, std::memory_order_release);
}

// Returns true when the PTO2_COMPLETION_DONE bit is present in ready_state.
// The value is read with an acquire load; other bits in ready_state are ignored.
bool is_completion_flag_set() const {
    const uint8_t flags = ready_state.load(std::memory_order_acquire);
    return (flags & PTO2_COMPLETION_DONE) != 0;
}

/**
* Reset dynamic scheduling fields for slot reuse.
* Called by advance_ring_pointers() after a slot transitions to CONSUMED
Expand All @@ -487,6 +505,7 @@ struct alignas(64) PTO2TaskSlotState {
completed_subtasks.store(0, std::memory_order_relaxed);
next_block_idx.store(0, std::memory_order_relaxed);
any_subtask_deferred.store(false, std::memory_order_relaxed);
ready_state.store(PTO2_READY_UNCLAIMED, std::memory_order_relaxed);
// Note: payload spec fields (spec_state / staged_core_mask / dispatch_fanin /
// spec_chain_*) are NOT reset here — this method skips the payload by
// contract. They are (re)initialized in PTO2TaskPayload::init on every
Expand Down
Loading