Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,13 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
std::default_delete<const void*[]>());


// Initialize all padding entries. For partial tiles (m < m_step),
// the kai LHS packing kernel may still read pointer entries beyond the logically
// filled 'm' positions. Leaving these uninitialized can cause non-deterministic
// reads and corrupt packed LHS data.
auto lhs_ptrs_ = lhs_ptrs.get();
std::fill(lhs_ptrs_, lhs_ptrs_ + (lhs_ptrs_k * lhs_ptrs_m), reinterpret_cast<const void*>(&pad_ptr[0]));

auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1);
auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1);

Expand Down Expand Up @@ -460,7 +467,23 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
// figure out how many blocks needed to correctly fill padding
padsize = ((ci + padsize - 1) / padsize) * padsize;
}
static std::vector<float>pad_ptr(padsize, 0.f);

// pad_ptr must be at least 'ci' floats for padding pixels.
// A thread_local, grow-only buffer avoids cross-thread interference and guarantees the buffer is never smaller than padsize.
thread_local std::vector<float> pad_ptr;
const float* old_pad_ptr = pad_ptr.data();
bool has_pad_ptr_changed = false;

if (pad_ptr.size() < padsize) {
pad_ptr.resize(padsize, 0.f);
if (pad_ptr.data() != old_pad_ptr) {
has_pad_ptr_changed = true;
}
} else {
// Ensure any previously-used region remains zeroed (grow-only means it should already be zeros,
// but keep this explicit for safety).
std::fill(pad_ptr.begin(), pad_ptr.end(), 0.f);
}

LhsCacheKey key = {
ci, ih, iw,
Expand All @@ -481,6 +504,16 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
// Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
thread_local std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>> lhs_ptrs_cache;

if (has_pad_ptr_changed)
{
// If the pad buffer was resized and a re-allocation has occurred, the cached lhs ptrs are invalid as they
// would be referencing the old pad buffer.
// See discussion in https://github.com/microsoft/onnxruntime/pull/27214.
// TODO(hasesh / JonathanC-ARM): A better approach would be to include the pad buffer address in the cache key
// or any other approach that would reduce unnecessary cache invalidations.
lhs_ptrs_cache.clear();
}

std::shared_ptr<const void*[]> lhs_ptrs;
if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
lhs_ptrs = found->second;
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cuda/nn/conv_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,18 @@ Status ConvTranspose<T, Layout>::UpdateState(OpKernelContext* context, bool dyna
" group: ", conv_transpose_attrs_.group);
}

// Bias shape validation: the bias must be a 1-D tensor of size M (num_output_channels).
// See https://github.com/microsoft/onnxruntime/issues/26144
if (B != nullptr) {
if (B->Shape().NumDimensions() != 1 || B->Shape()[0] != num_output_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Bias shape is not compatible with number of output channels."
" It should be a 1-D tensor with size num_output_channels(M).",
" Bias: ", B->Shape().ToString().c_str(),
" num_output_channels: ", num_output_channels);
}
}

TensorShapeVector kernel_shape;
ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(w_shape, kernel_shape, w_in_nhwc));

Expand Down
Loading