From 42658bfb412cfcd279e19b076a5028685f8d234a Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Mon, 11 Mar 2024 10:58:33 -0400 Subject: [PATCH] Remove old API calls --- experiments/ljubljana/register.py | 1 - experiments/ljubljana/train.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/experiments/ljubljana/register.py b/experiments/ljubljana/register.py index 4bcdb92..49d481b 100644 --- a/experiments/ljubljana/register.py +++ b/experiments/ljubljana/register.py @@ -10,7 +10,6 @@ from torchvision.transforms.functional import resize from tqdm import tqdm -from diffpose.calibration import RigidTransform, convert from diffpose.ljubljana import Evaluator, LjubljanaDataset, Transforms from diffpose.metrics import DoubleGeodesic, GeodesicSE3 from diffpose.registration import PoseRegressor, SparseRegistration diff --git a/experiments/ljubljana/train.py b/experiments/ljubljana/train.py index c0c06e8..ca66035 100644 --- a/experiments/ljubljana/train.py +++ b/experiments/ljubljana/train.py @@ -85,12 +85,12 @@ def train( try: offset = get_random_offset(view, batch_size, device) pose = isocenter_pose.compose(offset) - img = drr(None, None, None, pose=pose) + img = drr(pose) img = transforms(img) pred_offset = model(img) pred_pose = isocenter_pose.compose(pred_offset) - pred_img = drr(None, None, None, pose=pred_pose) + pred_img = drr(pred_pose) pred_img = transforms(pred_img) ncc = metric(pred_img, img)