Skip to content

Commit

Permalink
Revert "tiny improvments in convolve"
Browse files Browse the repository at this point in the history
This reverts commit 7281b30.
  • Loading branch information
sdpython committed Oct 25, 2021
1 parent a18e3be commit ef3dec8
Showing 1 changed file with 10 additions and 53 deletions.
63 changes: 10 additions & 53 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ Return Value:

size_t InputX = InitialInputX;
const float* InputRow = &Input[InputY * InputWidth];
const float* tempInputRow;

do {

Expand All @@ -173,28 +172,6 @@ Return Value:
}

CountX -= CountCopyX;
tempInputRow = &InputRow[InputX];

while (CountCopyX >= 16) {
MlasStoreFloat32x4(ColumnBuffer, MlasLoadFloat32x4(tempInputRow));
ColumnBuffer += 4;
tempInputRow += 4;

MlasStoreFloat32x4(ColumnBuffer, MlasLoadFloat32x4(tempInputRow));
ColumnBuffer += 4;
tempInputRow += 4;

MlasStoreFloat32x4(ColumnBuffer, MlasLoadFloat32x4(tempInputRow));
ColumnBuffer += 4;
tempInputRow += 4;

MlasStoreFloat32x4(ColumnBuffer, MlasLoadFloat32x4(tempInputRow));
ColumnBuffer += 4;
tempInputRow += 4;

InputX += 16;
CountCopyX -= 16;
}

while (CountCopyX >= 4) {
MlasStoreFloat32x4(ColumnBuffer, MlasLoadFloat32x4(&InputRow[InputX]));
Expand Down Expand Up @@ -233,22 +210,6 @@ Return Value:

MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4();

while (CountX >= 16) {
MlasStoreFloat32x4(ColumnBuffer, ZeroFloat32x4);
ColumnBuffer += 4;

MlasStoreFloat32x4(ColumnBuffer, ZeroFloat32x4);
ColumnBuffer += 4;

MlasStoreFloat32x4(ColumnBuffer, ZeroFloat32x4);
ColumnBuffer += 4;

MlasStoreFloat32x4(ColumnBuffer, ZeroFloat32x4);
ColumnBuffer += 4;

CountX -= 16;
}

while (CountX >= 4) {
MlasStoreFloat32x4(ColumnBuffer, ZeroFloat32x4);
ColumnBuffer += 4;
Expand Down Expand Up @@ -566,8 +527,6 @@ Return Value:
const size_t FilterCount = Parameters->FilterCount;
const size_t OutputSize = Parameters->OutputSize;
const size_t K = Parameters->K;
const size_t K2 = Parameters->K * 2;
const size_t SegmentCountN2 = SegmentCountN * 2;

//
// Compute the strides to step through slices of the local segment.
Expand All @@ -580,14 +539,14 @@ Return Value:

if (SegmentCountN >= K) {

while (StrideK >= K2) {
while (StrideK / 2 >= K) {
StrideN *= 2;
StrideK /= 2;
}

} else {

while (StrideN > 16 && StrideN >= SegmentCountN2) {
while (StrideN > 16 && StrideN / 2 >= SegmentCountN) {
StrideK *= 2;
StrideN /= 2;
}
Expand All @@ -601,12 +560,11 @@ Return Value:

for (size_t n = 0; n < SegmentCountN; n += CountN) {

// CountN = SegmentCountN - n;
CountN = SegmentCountN - n;

CountN = SegmentCountN - n > StrideN ? StrideN : SegmentCountN - n;
// if (CountN > StrideN) {
// CountN = StrideN;
//}
if (CountN > StrideN) {
CountN = StrideN;
}

//
// Step through each slice of the input tensor along the K dimension.
Expand All @@ -618,12 +576,11 @@ Return Value:

for (size_t k = 0; k < K; k += CountK) {

// CountK = K - k;
CountK = K - k;

CountK = K - k > StrideK ? StrideK : K - k;
//if (CountK > StrideK) {
// CountK = StrideK;
//}
if (CountK > StrideK) {
CountK = StrideK;
}

if (Parameters->Dimensions == 2) {
MlasConvIm2Col(Parameters, Input, ColumnBuffer, k, CountK,
Expand Down

0 comments on commit ef3dec8

Please sign in to comment.