diff --git a/experiments/deepfluoro/register.py b/experiments/deepfluoro/register.py index 4861950..89d9d2f 100644 --- a/experiments/deepfluoro/register.py +++ b/experiments/deepfluoro/register.py @@ -7,10 +7,10 @@ import torch from diffdrr.drr import DRR from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d +from diffdrr.pose import RigidTransform, convert from torchvision.transforms.functional import resize from tqdm import tqdm -from diffpose.calibration import RigidTransform, convert from diffpose.deepfluoro import DeepFluoroDataset, Evaluator, Transforms from diffpose.metrics import DoubleGeodesic, GeodesicSE3 from diffpose.registration import PoseRegressor, SparseRegistration @@ -38,7 +38,7 @@ def __init__( self.geodesics = GeodesicSE3() self.doublegeo = DoubleGeodesic(sdr=self.specimen.focal_len / 2) - self.criterion = MultiscaleNormalizedCrossCorrelation2d([None, 13], [0.5, 0.5]) + self.criterion = MultiscaleNormalizedCrossCorrelation2d([None, 9], [0.5, 0.5]) self.transforms = Transforms(self.drr.detector.height) self.parameterization = parameterization self.convention = convention