FromVec broadcasting, increase torch version
zachteed committed Aug 29, 2021
1 parent 82b0223 commit 0fa9ce8
Showing 3 changed files with 4 additions and 4 deletions.
README.md (1 addition, 1 deletion)

```diff
@@ -25,7 +25,7 @@ Zachary Teed and Jia Deng, CVPR 2021

 ### Requirements:
 * Cuda >= 10.1 (with nvcc compiler)
-* PyTorch >= 1.7
+* PyTorch >= 1.8

 We recommend installing within a virtual environment. Make sure you clone using the `--recursive` flag. If you are using Anaconda, the following command can be used to install all dependencies
```
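The version bump appears to track `torch.linalg.pinv`, which the backward pass in `group_ops.py` (below) relies on and which became available in PyTorch 1.8. A quick sanity check for an existing environment (my addition, not part of the commit; assumes the `packaging` module is installed):

```python
import torch
from packaging import version

# Illustrative check against the bumped requirements; not part of this commit.
assert version.parse(torch.__version__) >= version.parse("1.8"), \
    "lietorch now requires PyTorch >= 1.8"
print("CUDA toolkit version:", torch.version.cuda)    # README asks for CUDA >= 10.1
print("CUDA device available:", torch.cuda.is_available())
```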
lietorch/group_ops.py (2 additions, 2 deletions)

```diff
@@ -83,7 +83,7 @@ def forward(cls, ctx, group_id, *inputs):
     def backward(cls, ctx, grad):
         inputs = ctx.saved_tensors
         J = lietorch_backends.projector(ctx.group_id, *inputs)
-        return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J))
+        return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J)).squeeze(-2)

 class ToVec(torch.autograd.Function):
     """ convert group object to vector """
@@ -98,5 +98,5 @@ def forward(cls, ctx, group_id, *inputs):
     def backward(cls, ctx, grad):
         inputs = ctx.saved_tensors
         J = lietorch_backends.projector(ctx.group_id, *inputs)
-        return None, torch.matmul(grad.unsqueeze(-2), J)
+        return None, torch.matmul(grad.unsqueeze(-2), J).squeeze(-2)
```
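These two one-line changes are the substance of the broadcasting fix. `grad.unsqueeze(-2)` turns the incoming gradient into a row vector so that `torch.matmul` broadcasts over any leading batch dimensions, but the product then carries a singleton dimension that must be squeezed away for the returned gradient to match the shape of the input. A shape-only sketch of the arithmetic (dimensions are illustrative, assuming a square projector and SE3's embedded dimension of 7):

```python
import torch

batch, D = (1, 2), 7                 # leading batch dims and embedded_dim (SE3)
grad = torch.randn(*batch, D)        # incoming gradient, shape (1, 2, 7)
J = torch.randn(*batch, D, D)        # per-element Jacobian, shape (1, 2, 7, 7)

row = grad.unsqueeze(-2)             # row vector: (1, 2, 1, 7)
out = torch.matmul(row, J)           # broadcasts over batch dims: (1, 2, 1, 7)

# Without .squeeze(-2) the stray singleton dimension remains, and autograd
# rejects the gradient because it no longer matches the (1, 2, 7) input.
print(out.shape)                     # torch.Size([1, 2, 1, 7])
print(out.squeeze(-2).shape)         # torch.Size([1, 2, 7])
```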

lietorch/run_tests.py (1 addition, 1 deletion)

```diff
@@ -218,7 +218,7 @@ def fn(a):
         return Group.InitFromVec(a).vec()

     D = Group.embedded_dim
-    a = torch.randn(1, D, requires_grad=True, device=device).double()
+    a = torch.randn(1, 2, D, requires_grad=True, device=device).double()

     analytical, numerical = gradcheck(fn, [a], eps=1e-4)
```
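The test now feeds a batched input of shape `(1, 2, D)` instead of `(1, D)`, so `gradcheck` exercises the broadcasting path and would catch a gradient returned with a stray singleton dimension. A self-contained sketch of the same pattern, using a stand-in for `Group.InitFromVec(a).vec()` since the real call needs the compiled backend (the `normalize` stand-in and `D = 7` are my assumptions):

```python
import torch
from torch.autograd import gradcheck

def fn(a):
    # stand-in for Group.InitFromVec(a).vec(); any smooth vector-to-vector
    # map suffices to demonstrate gradcheck on a batched input
    return torch.nn.functional.normalize(a, dim=-1)

D = 7  # embedded_dim for SE3
a = torch.randn(1, 2, D, dtype=torch.double, requires_grad=True)

# compares analytical and numerical Jacobians over the batched input
print(gradcheck(fn, [a], eps=1e-4))
```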
