Skip to content

Commit

Permalink
fix key not found; fix general_vfov_to_focal when batched input
Browse files Browse the repository at this point in the history
  • Loading branch information
jinlinyi committed Apr 18, 2024
1 parent be8caf3 commit d54be73
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,4 @@ def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):

print("Alternatively, inference a batch of images")
predictions = pf_model.inference_batch(img_bgr_list=[img_bgr, img_bgr, img_bgr])
breakpoint()
breakpoint()
16 changes: 7 additions & 9 deletions perspective2d/modeling/param_network/param_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,13 @@ def forward(self, predictions, batched_inputs=None):
param["pred_general_vfov"] = param["pred_vfov"]
if "pred_rel_focal" not in param:
param["pred_rel_focal"] = torch.FloatTensor(
[
general_vfov_to_focal(
to_numpy(param["pred_rel_cx"]),
to_numpy(param["pred_rel_cy"]),
1,
to_numpy(param["pred_general_vfov"]),
degree=True,
)
]
general_vfov_to_focal(
to_numpy(param["pred_rel_cx"]),
to_numpy(param["pred_rel_cy"]),
1,
to_numpy(param["pred_general_vfov"]),
degree=True,
)
)
return param

Expand Down
2 changes: 1 addition & 1 deletion perspective2d/perspectivefields.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def forward(self, batched_inputs) -> dict:
param["pred_rel_cx"] = torch.zeros_like(param["pred_vfov"])
if "pred_rel_cy" not in param.keys():
param["pred_rel_cy"] = torch.zeros_like(param["pred_vfov"])
assert len(processed_results) == len(param["pred_vfov"])
assert len(processed_results) == len(param["pred_general_vfov"])
for i in range(len(processed_results)):
param_tmp = {k: v[i] for k, v in param.items()}
processed_results[i].update(param_tmp)
Expand Down
8 changes: 6 additions & 2 deletions perspective2d/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ def fun(focal, *args):
q_sqr = (focal / h) ** 2 + d_cx**2 + (d_cy - 0.5) ** 2
cos_FoV = (p_sqr + q_sqr - 1) / 2 / np.sqrt(p_sqr) / np.sqrt(q_sqr)
return cos_FoV - target_cos_FoV

if degree:
gvfov = np.radians(gvfov)
focal = scipy.optimize.fsolve(fun, 1.5, args=(h, rel_cx, rel_cy, np.cos(gvfov)))[0]
if type(rel_cx) != np.ndarray:
# if input is float
focal = scipy.optimize.fsolve(fun, 1.5, args=(h, rel_cx, rel_cy, np.cos(gvfov)))[0]
else:
# if input is numpy array
focal = scipy.optimize.fsolve(fun, np.ones(len(rel_cx)) * 1.5, args=(h, rel_cx, rel_cy, np.cos(gvfov)))
focal = np.abs(focal)
return focal

Expand Down

0 comments on commit d54be73

Please sign in to comment.