Skip to content

Commit

Permalink
Improve ARPS performance
Browse files Browse the repository at this point in the history
  • Loading branch information
tjof2 committed Jun 17, 2020
1 parent 1cb4081 commit 2be26fa
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 85 deletions.
153 changes: 70 additions & 83 deletions src/arps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,29 @@ namespace pguresvt
uint32_t Nx, Ny, Nt;
uint32_t nxMbs, nyMbs, vecSize;
double OoBlockSizeSq;

const double costsScale = 1E8;
const uint32_t maxSDSP = 1E6;

inline double CostFunction(const arma::Cube<T> &A, const arma::cube &B)
{
return arma::accu(arma::square(A - B)) * OoBlockSizeSq;
}

void ARPSMotionEstimation(const int curFrame, const int iARPS1, const int iARPS2, const int iARPS3)
{
double norm = 0;
arma::umat checkMat = arma::zeros<arma::umat>(2 * motionWindow + 1, 2 * motionWindow + 1);

arma::vec::fixed<6> costs;
arma::imat::fixed<6, 2> LDSP;
arma::imat::fixed<5, 2> SDSP;

arma::Cube<T> refBlock, newBlock, powBlock;

refBlock.set_size(blockSize, blockSize, 1);
newBlock.set_size(blockSize, blockSize, 1);
powBlock.set_size(blockSize, blockSize, 1);

for (size_t it = 0; it < vecSize; it++)
{
costs.fill(costsScale);
Expand All @@ -184,17 +196,15 @@ namespace pguresvt
int x = j;
int y = i;

arma::Cube<T> refBlock = A(arma::span(i, i + blockSize - 1),
arma::span(j, j + blockSize - 1),
arma::span(iARPS1));
refBlock = A(arma::span(i, i + blockSize - 1),
arma::span(j, j + blockSize - 1),
arma::span(iARPS1));

arma::Cube<T> newBlock = A(arma::span(i, i + blockSize - 1),
arma::span(j, j + blockSize - 1),
arma::span(iARPS2));

norm = arma::norm(refBlock.slice(0) - newBlock.slice(0), "fro");
costs(2) = norm * norm * OoBlockSizeSq;
newBlock = A(arma::span(i, i + blockSize - 1),
arma::span(j, j + blockSize - 1),
arma::span(iARPS2));

costs(2) = CostFunction(refBlock, newBlock);
checkMat(motionWindow, motionWindow) = 1;

uint32_t maxIdx;
Expand All @@ -209,8 +219,10 @@ namespace pguresvt
{
int yTmp = std::abs(motions(0, it, iARPS3));
int xTmp = std::abs(motions(1, it, iARPS3));

stepSize = (xTmp <= yTmp) ? yTmp : xTmp;
if ((yTmp == 0 && xTmp == stepSize) || (xTmp == 0 && yTmp == stepSize))

if (((yTmp == 0) && (xTmp == stepSize)) || ((xTmp == 0) && (yTmp == stepSize)))
{
maxIdx = 5;
}
Expand All @@ -235,7 +247,7 @@ namespace pguresvt

// Currently not used, but motion estimation can be
// predictive if this value is larger than 0
double pMotion = 0.0;
double pMotion = -1.0;

bool skipIt = false;

Expand All @@ -253,46 +265,31 @@ namespace pguresvt

if (!skipIt) // Only evaluate if none of the above is true
{
arma::Cube<T> powBlock = A(arma::span(refBlkVer, refBlkVer + blockSize - 1),
arma::span(refBlkHor, refBlkHor + blockSize - 1),
arma::span(iARPS2));
if (curFrame == 0)
powBlock = A(arma::span(refBlkVer, refBlkVer + blockSize - 1),
arma::span(refBlkHor, refBlkHor + blockSize - 1),
arma::span(iARPS2));

costs(k) = CostFunction(refBlock, powBlock);

if ((pMotion > 0.0) && (curFrame < 0))
{
norm = arma::norm(refBlock.slice(0) - powBlock.slice(0), "fro");
costs(k) = norm * norm * OoBlockSizeSq;
arma::ivec predPos = arma::vectorise(
patches(arma::span(), arma::span(it), arma::span(iARPS1)) -
motions(arma::span(), arma::span(it), arma::span(iARPS3)));
costs(k) += pMotion * std::sqrt(std::pow(predPos(0) - refBlkVer, 2) +
std::pow(predPos(1) - refBlkHor, 2));
}
else if (curFrame < 0)
else if ((pMotion > 0.0) && (curFrame > 0))
{
norm = arma::norm(refBlock.slice(0) - powBlock.slice(0), "fro");
costs(k) = norm * norm * OoBlockSizeSq;

if (pMotion > 0.0)
{
arma::ivec predPos = arma::vectorise(
patches(arma::span(), arma::span(it), arma::span(iARPS1)) -
motions(arma::span(), arma::span(it), arma::span(iARPS3)));
costs(k) += pMotion * std::sqrt(std::pow(predPos(0) - refBlkVer, 2) +
std::pow(predPos(1) - refBlkHor, 2));
}
arma::ivec predPos = arma::vectorise(
patches(arma::span(), arma::span(it), arma::span(iARPS1)) +
motions(arma::span(), arma::span(it), arma::span(iARPS3)));
costs(k) += pMotion * std::sqrt(std::pow(predPos(0) - refBlkVer, 2) +
std::pow(predPos(1) - refBlkHor, 2));
}
else if (curFrame > 0)
{
norm = arma::norm(refBlock.slice(0) - powBlock.slice(0), "fro");
costs(k) = norm * norm * OoBlockSizeSq;

if (pMotion > 0.0)
{
arma::ivec predPos = arma::vectorise(
patches(arma::span(), arma::span(it), arma::span(iARPS1)) +
motions(arma::span(), arma::span(it), arma::span(iARPS3)));

norm = arma::norm(refBlock.slice(0) - powBlock.slice(0), "fro");
costs(k) += pMotion * std::sqrt(std::pow(predPos(0) - refBlkVer, 2) +
std::pow(predPos(1) - refBlkHor, 2));
}
}

checkMat(LDSP(k, 1) + motionWindow, LDSP(k, 0) + motionWindow) = 1;
checkMat(LDSP(k, 1) + motionWindow,
LDSP(k, 0) + motionWindow) = 1;
}
}

Expand All @@ -304,6 +301,7 @@ namespace pguresvt
costs(2) = cost;

bool doneFlag = false;
uint32_t nSDSP = 0;

do // Do the SDSP
{
Expand All @@ -328,53 +326,39 @@ namespace pguresvt

if (!skipIt) // Only evaluate if none of the above is true
{
arma::Cube<T> powBlock = A(arma::span(refBlkVer, refBlkVer + blockSize - 1),
arma::span(refBlkHor, refBlkHor + blockSize - 1),
arma::span(iARPS2));
if (curFrame == 0)
{
norm = arma::norm(refBlock.slice(0) - powBlock.slice(0), "fro");
costs(k) = norm * norm * OoBlockSizeSq;
}
else if (curFrame < 0)
powBlock = A(arma::span(refBlkVer, refBlkVer + blockSize - 1),
arma::span(refBlkHor, refBlkHor + blockSize - 1),
arma::span(iARPS2));

costs(k) = CostFunction(refBlock, powBlock);

if ((pMotion > 0.0) && (curFrame < 0))
{
norm = arma::norm(refBlock.slice(0) - powBlock.slice(0), "fro");
costs(k) = norm * norm * OoBlockSizeSq;

if (pMotion > 0.0)
{
arma::ivec predPos = arma::vectorise(
patches(arma::span(), arma::span(it), arma::span(iARPS1)) -
motions(arma::span(), arma::span(it), arma::span(iARPS3)));

costs(k) += pMotion * std::sqrt(std::pow(predPos(0) - refBlkVer, 2) +
std::pow(predPos(1) - refBlkHor, 2));
}

arma::ivec predPos = arma::vectorise(
patches(arma::span(), arma::span(it), arma::span(iARPS1)) -
motions(arma::span(), arma::span(it), arma::span(iARPS3)));
costs(k) += pMotion * std::sqrt(std::pow(predPos(0) - refBlkVer, 2) +
std::pow(predPos(1) - refBlkHor, 2));
}
else if (curFrame > 0)
else if ((pMotion > 0.0) && (curFrame > 0))
{
norm = arma::norm(refBlock.slice(0) - powBlock.slice(0), "fro");
costs(k) = norm * norm * OoBlockSizeSq;

if (pMotion > 0.0)
{
arma::ivec predPos = arma::vectorise(
patches(arma::span(), arma::span(it), arma::span(iARPS1)) +
motions(arma::span(), arma::span(it), arma::span(iARPS3)));

costs(k) += pMotion * std::sqrt(std::pow(predPos(0) - refBlkVer, 2) +
std::pow(predPos(1) - refBlkHor, 2));
}
arma::ivec predPos = arma::vectorise(
patches(arma::span(), arma::span(it), arma::span(iARPS1)) +
motions(arma::span(), arma::span(it), arma::span(iARPS3)));
costs(k) += pMotion * std::sqrt(std::pow(predPos(0) - refBlkVer, 2) +
std::pow(predPos(1) - refBlkHor, 2));
}

checkMat(y - i + SDSP(k, 1) + motionWindow, x - j + SDSP(k, 0) + motionWindow) = 1;
checkMat(y - i + SDSP(k, 1) + motionWindow,
x - j + SDSP(k, 0) + motionWindow) = 1;
}
}

point = arma::find(costs == costs.min());
cost = costs.min();

if (point(0) == 2)
if ((point(0) == 2) || (nSDSP >= maxSDSP))
{
doneFlag = true;
}
Expand All @@ -385,6 +369,9 @@ namespace pguresvt
costs.fill(costsScale);
costs(2) = cost;
}

nSDSP++;

} while (!doneFlag);

motions(0, it, iARPS3) = y - i;
Expand Down
6 changes: 4 additions & 2 deletions src/pguresvt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@ uint32_t PGURESVT(arma::Cube<T2> &Y,
w = arma::conv_to<arma::Cube<T2>>::from(Z.slices(timeIter - frameWindow, timeIter + frameWindow));
}

double uMax = u.max(); // Basic sequence normalization
double uMax = u.max(); // Sequence normalization
double wMax = w.max(); // ARPS relies on reasonably well-scaled data

u /= uMax;
w /= uMax;
w /= wMax;

if (optPGURE) // Perform noise estimation
{
Expand Down

0 comments on commit 2be26fa

Please sign in to comment.