Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,12 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
auto lhs_ptrs = std::shared_ptr<const void*[]>(new const void*[lhs_ptrs_k * lhs_ptrs_m],
std::default_delete<const void*[]>());

// Initialize every pointer entry to the pad buffer. For partial tiles (m < m_step),
// the kai LHS packing kernel may still read pointer entries beyond the logically
// filled 'm' positions. Leaving these uninitialized can cause non-deterministic
// reads and corrupt the packed LHS data.
auto lhs_ptrs_ = lhs_ptrs.get();
std::fill(lhs_ptrs_, lhs_ptrs_ + (lhs_ptrs_k * lhs_ptrs_m), reinterpret_cast<const void*>(&pad_ptr[0]));

auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1);
auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1);
Expand Down Expand Up @@ -430,7 +436,6 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
};

size_t m_{0};
auto lhs_ptrs_ = lhs_ptrs.get();
for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) {
for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) {
size_t k_{0};
Expand Down Expand Up @@ -460,7 +465,23 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
// figure out how many blocks needed to correctly fill padding
padsize = ((ci + padsize - 1) / padsize) * padsize;
}
static std::vector<float>pad_ptr(padsize, 0.f);

// pad_ptr must hold at least 'ci' floats for padding pixels.
// Use a thread_local grow-only buffer to avoid cross-thread interference and to ensure the sizing is correct.
thread_local std::vector<float> pad_ptr;
const float* old_pad_ptr = pad_ptr.data();
bool has_pad_ptr_changed = false;

if (pad_ptr.size() < padsize) {
pad_ptr.resize(padsize, 0.f);
if (pad_ptr.data() != old_pad_ptr) {
has_pad_ptr_changed = true;
}
} else {
// Ensure any previously-used region remains zeroed (grow-only means it should already be zeros,
// but keep this explicit for safety).
std::fill(pad_ptr.begin(), pad_ptr.end(), 0.f);
}

LhsCacheKey key = {
ci, ih, iw,
Expand All @@ -481,6 +502,16 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
// Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
thread_local std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>> lhs_ptrs_cache;

if (has_pad_ptr_changed)
{
// If the pad buffer was resized and a re-allocation has occurred, the cached lhs ptrs are invalid as they
// would be referencing the old pad buffer.
// See discussion in https://github.com/microsoft/onnxruntime/pull/27214.
// TODO(hasesh / JonathanC-ARM): A better approach would be to include the pad buffer address in the cache key
// or any other approach that would reduce unnecessary cache invalidations.
lhs_ptrs_cache.clear();
}

std::shared_ptr<const void*[]> lhs_ptrs;
if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
lhs_ptrs = found->second;
Expand Down
Loading