# plattice.py
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import PLOO
from torch.nn import Module


class PermutoLatticeFunction(Function):
    """Autograd wrapper around the PLOO extension's forward/backward kernels."""

    @staticmethod
    def forward(ctx, feature, values):
        # PLOO.forward returns both the weights (needed for backward) and the output.
        weight, out = PLOO.forward(feature, values)
        ctx.save_for_backward(weight, feature)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        weight, feature = ctx.saved_tensors
        grad_values = PLOO.backward(feature, grad_output.contiguous(), weight)
        # No gradient with respect to `feature`; only `values` needs to be back-propagated.
        return None, grad_values


class PermutoLattice(Module):
    """nn.Module wrapper so the lattice can be used inside a network."""

    def forward(self, feature, values):
        return PermutoLatticeFunction.apply(feature, values)
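

# Usage sketch (assumptions: CUDA float32 inputs shaped like the self-test below;
# the variable names here are illustrative, not taken from the file):
#
#   lattice = PermutoLattice()
#   feature = torch.randn(5, 5, 3, device="cuda")
#   values = torch.randn(5, 5, 3, device="cuda", requires_grad=True)
#   out = lattice(feature, values)   # calls PermutoLatticeFunction.apply
#   out.sum().backward()             # gradient reaches `values`; `feature` gets none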


if __name__ == "__main__":
    # Check the gradient of the function.
    ft = torch.randn(5, 5, 3, dtype=torch.float32, requires_grad=False).cuda()
    v = torch.randn(5, 5, 3, dtype=torch.float32, requires_grad=True).cuda()
    v.retain_grad()
    v_dv = torch.randn(5, 5, 3, dtype=torch.float32, requires_grad=True).cuda()
    v_dv.retain_grad()

    # Manual check against the raw PLOO ops, kept for reference:
    # weight, out = PLOO.forward(ft, v)
    # __, g_out = PLOO.forward(ft, v_dv)
    # grad_v = PLOO.backward(ft, torch.ones_like(out), weight)
    # _, out2 = PLOO.forward(ft, v + v_dv)
    # grad_cha = grad_v * g_out
    # real_cha = out2 - out
    # cha = grad_cha - real_cha
    # print(torch.mean(cha))

    plattice = PermutoLatticeFunction.apply
    c = v + v_dv
    c.retain_grad()  # c is a non-leaf tensor, so its .grad must be retained explicitly
    a = plattice(ft, c)
    loss = 1 - torch.sum(a)
    loss.backward()
    print(c.grad)
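
    # A minimal numeric cross-check (sketch): compare the analytic gradient from
    # loss.backward() with a one-sided finite-difference estimate along a fixed
    # direction, in the same spirit as the commented-out PLOO check above.
    # `eps` and `direction` are arbitrary illustrative choices.
    with torch.no_grad():
        eps = 1e-3
        direction = torch.ones_like(c)
        out_base = plattice(ft, c)
        out_step = plattice(ft, c + eps * direction)
        # Numerical estimate of d(sum(a))/dc dotted with `direction` ...
        fd_estimate = (torch.sum(out_step) - torch.sum(out_base)) / eps
        # ... and the same quantity from the analytic gradient
        # (loss = 1 - sum(a), hence c.grad = -d(sum(a))/dc).
        analytic = torch.sum(-c.grad * direction)
        print("finite difference:", fd_estimate.item(), "analytic:", analytic.item())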