Skip to content

Commit 301f40d

Browse files
authored
Merge pull request #1970 from gdevenyi/fix-registration-image-cache-masks
ENH: Deduplicate all image reads in antsRegistration
2 parents 9e4f789 + aaf032b commit 301f40d

1 file changed

Lines changed: 39 additions & 40 deletions

File tree

Examples/antsRegistrationTemplateHeader.h

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,36 @@ DoRegistration(typename ParserType::Pointer & parser)
327327
}
328328
}
329329

330+
std::map<std::string, typename ImageType::Pointer> imageCache;
331+
332+
auto getCachedImage = [&imageCache](const std::string & filename) -> typename ImageType::Pointer {
333+
std::error_code ec;
334+
std::filesystem::path canonicalPath = std::filesystem::canonical(filename, ec);
335+
std::string key = ec ? filename : canonicalPath.string();
336+
typename ImageType::Pointer & cached = imageCache[key];
337+
if (cached.IsNull())
338+
{
339+
ReadImage<ImageType>(cached, filename.c_str());
340+
cached->DisconnectPipeline();
341+
}
342+
return cached;
343+
};
344+
345+
std::map<std::string, typename MaskImageType::Pointer> maskCache;
346+
347+
auto getCachedMask = [&maskCache](const std::string & filename) -> typename MaskImageType::Pointer {
348+
std::error_code ec;
349+
std::filesystem::path canonicalPath = std::filesystem::canonical(filename, ec);
350+
std::string key = ec ? filename : canonicalPath.string();
351+
typename MaskImageType::Pointer & cached = maskCache[key];
352+
if (cached.IsNull())
353+
{
354+
ReadImage<MaskImageType>(cached, filename.c_str());
355+
cached->DisconnectPipeline();
356+
}
357+
return cached;
358+
};
359+
330360
if (maskOption && maskOption->GetNumberOfFunctions())
331361
{
332362
if (verbose)
@@ -344,8 +374,7 @@ DoRegistration(typename ParserType::Pointer & parser)
344374
for (unsigned m = 0; m < maskOption->GetFunction(l)->GetNumberOfParameters(); m++)
345375
{
346376
std::string fname = maskOption->GetFunction(l)->GetParameter(m);
347-
typename MaskImageType::Pointer maskImage;
348-
ReadImage<MaskImageType>(maskImage, fname.c_str());
377+
typename MaskImageType::Pointer maskImage = getCachedMask(fname);
349378
if (m == 0)
350379
{
351380
regHelper->AddFixedImageMask(maskImage);
@@ -381,8 +410,7 @@ DoRegistration(typename ParserType::Pointer & parser)
381410
else
382411
{
383412
std::string fname = maskOption->GetFunction(l)->GetName();
384-
typename MaskImageType::Pointer maskImage;
385-
ReadImage<MaskImageType>(maskImage, fname.c_str());
413+
typename MaskImageType::Pointer maskImage = getCachedMask(fname);
386414
regHelper->AddFixedImageMask(maskImage);
387415
if (verbose)
388416
{
@@ -680,9 +708,7 @@ DoRegistration(typename ParserType::Pointer & parser)
680708

681709
if (meshSizeForTheUpdateField.size() == 1)
682710
{
683-
typename ImageType::Pointer fixedImage;
684-
ReadImage<ImageType>(fixedImage, fixedImageFileName.c_str());
685-
fixedImage->DisconnectPipeline();
711+
typename ImageType::Pointer fixedImage = getCachedImage(fixedImageFileName);
686712

687713
meshSizeForTheUpdateField =
688714
regHelper->CalculateMeshSizeForSpecifiedKnotSpacing(fixedImage, meshSizeForTheUpdateField[0], splineOrder);
@@ -695,9 +721,7 @@ DoRegistration(typename ParserType::Pointer & parser)
695721
parser->ConvertVector<unsigned int>(transformOption->GetFunction(currentStage)->GetParameter(2));
696722
if (meshSizeForTheTotalField.size() == 1)
697723
{
698-
typename ImageType::Pointer fixedImage;
699-
ReadImage<ImageType>(fixedImage, fixedImageFileName.c_str());
700-
fixedImage->DisconnectPipeline();
724+
typename ImageType::Pointer fixedImage = getCachedImage(fixedImageFileName);
701725

702726
meshSizeForTheTotalField =
703727
regHelper->CalculateMeshSizeForSpecifiedKnotSpacing(fixedImage, meshSizeForTheTotalField[0], splineOrder);
@@ -721,9 +745,7 @@ DoRegistration(typename ParserType::Pointer & parser)
721745
parser->ConvertVector<unsigned int>(transformOption->GetFunction(currentStage)->GetParameter(1));
722746
if (meshSizeAtBaseLevel.size() == 1)
723747
{
724-
typename ImageType::Pointer fixedImage;
725-
ReadImage<ImageType>(fixedImage, fixedImageFileName.c_str());
726-
fixedImage->DisconnectPipeline();
748+
typename ImageType::Pointer fixedImage = getCachedImage(fixedImageFileName);
727749

728750
meshSizeAtBaseLevel =
729751
regHelper->CalculateMeshSizeForSpecifiedKnotSpacing(fixedImage, meshSizeAtBaseLevel[0], 3);
@@ -802,9 +824,7 @@ DoRegistration(typename ParserType::Pointer & parser)
802824
parser->ConvertVector<float>(transformOption->GetFunction(currentStage)->GetParameter(1));
803825
if (meshSizeForTheUpdateFieldFloat.size() == 1)
804826
{
805-
typename ImageType::Pointer fixedImage;
806-
ReadImage<ImageType>(fixedImage, fixedImageFileName.c_str());
807-
fixedImage->DisconnectPipeline();
827+
typename ImageType::Pointer fixedImage = getCachedImage(fixedImageFileName);
808828

809829
meshSizeForTheUpdateField = regHelper->CalculateMeshSizeForSpecifiedKnotSpacing(
810830
fixedImage, meshSizeForTheUpdateFieldFloat[0], splineOrder);
@@ -822,9 +842,7 @@ DoRegistration(typename ParserType::Pointer & parser)
822842
parser->ConvertVector<float>(transformOption->GetFunction(currentStage)->GetParameter(2));
823843
if (meshSizeForTheTotalFieldFloat.size() == 1)
824844
{
825-
typename ImageType::Pointer fixedImage;
826-
ReadImage<ImageType>(fixedImage, fixedImageFileName.c_str());
827-
fixedImage->DisconnectPipeline();
845+
typename ImageType::Pointer fixedImage = getCachedImage(fixedImageFileName);
828846

829847
meshSizeForTheTotalField = regHelper->CalculateMeshSizeForSpecifiedKnotSpacing(
830848
fixedImage, meshSizeForTheTotalFieldFloat[0], splineOrder);
@@ -876,9 +894,7 @@ DoRegistration(typename ParserType::Pointer & parser)
876894
parser->ConvertVector<unsigned int>(transformOption->GetFunction(currentStage)->GetParameter(1));
877895
if (meshSizeForTheUpdateField.size() == 1)
878896
{
879-
typename ImageType::Pointer fixedImage;
880-
ReadImage<ImageType>(fixedImage, fixedImageFileName.c_str());
881-
fixedImage->DisconnectPipeline();
897+
typename ImageType::Pointer fixedImage = getCachedImage(fixedImageFileName);
882898

883899
meshSizeForTheUpdateField =
884900
regHelper->CalculateMeshSizeForSpecifiedKnotSpacing(fixedImage, meshSizeForTheUpdateField[0], splineOrder);
@@ -891,9 +907,7 @@ DoRegistration(typename ParserType::Pointer & parser)
891907
parser->ConvertVector<unsigned int>(transformOption->GetFunction(currentStage)->GetParameter(2));
892908
if (meshSizeForTheVelocityField.size() == 1)
893909
{
894-
typename ImageType::Pointer fixedImage;
895-
ReadImage<ImageType>(fixedImage, fixedImageFileName.c_str());
896-
fixedImage->DisconnectPipeline();
910+
typename ImageType::Pointer fixedImage = getCachedImage(fixedImageFileName);
897911

898912
meshSizeForTheVelocityField = regHelper->CalculateMeshSizeForSpecifiedKnotSpacing(
899913
fixedImage, meshSizeForTheVelocityField[0], splineOrder);
@@ -947,21 +961,6 @@ DoRegistration(typename ParserType::Pointer & parser)
947961
// ID to the added metric. Multiple metrics for a single stage are specified
948962
// on the command line by being specified adjacently.
949963

950-
std::map<std::string, typename ImageType::Pointer> imageCache;
951-
952-
auto getCachedImage = [&imageCache](const std::string & filename) -> typename ImageType::Pointer {
953-
std::error_code ec;
954-
std::filesystem::path canonicalPath = std::filesystem::canonical(filename, ec);
955-
std::string key = ec ? filename : canonicalPath.string();
956-
typename ImageType::Pointer & cached = imageCache[key];
957-
if (cached.IsNull())
958-
{
959-
ReadImage<ImageType>(cached, filename.c_str());
960-
cached->DisconnectPipeline();
961-
}
962-
return cached;
963-
};
964-
965964
unsigned int numberOfMetrics = metricOption->GetNumberOfFunctions();
966965
for (int currentMetricNumber = numberOfMetrics - 1; currentMetricNumber >= 0; currentMetricNumber--)
967966
{

0 commit comments

Comments
 (0)