diff --git a/experiments/ljubljana/register.py b/experiments/ljubljana/register.py index 4bcdb92..49d481b 100644 --- a/experiments/ljubljana/register.py +++ b/experiments/ljubljana/register.py @@ -10,7 +10,6 @@ from torchvision.transforms.functional import resize from tqdm import tqdm -from diffpose.calibration import RigidTransform, convert from diffpose.ljubljana import Evaluator, LjubljanaDataset, Transforms from diffpose.metrics import DoubleGeodesic, GeodesicSE3 from diffpose.registration import PoseRegressor, SparseRegistration diff --git a/experiments/ljubljana/train.py b/experiments/ljubljana/train.py index c0c06e8..ca66035 100644 --- a/experiments/ljubljana/train.py +++ b/experiments/ljubljana/train.py @@ -85,12 +85,12 @@ def train( try: offset = get_random_offset(view, batch_size, device) pose = isocenter_pose.compose(offset) - img = drr(None, None, None, pose=pose) + img = drr(pose) img = transforms(img) pred_offset = model(img) pred_pose = isocenter_pose.compose(pred_offset) - pred_img = drr(None, None, None, pose=pred_pose) + pred_img = drr(pred_pose) pred_img = transforms(pred_img) ncc = metric(pred_img, img)