-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_time_mem.py
115 lines (93 loc) · 3.43 KB
/
test_time_mem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import time
import numpy as np
import torch
from torch_efficient_distloss import eff_distloss, eff_distloss_native, flatten_eff_distloss
def original_distloss(w, m, interval):
    '''
    Brute-force O(N^2) distortion loss, kept as a correctness/speed reference.

    There are B rays, each with N sampled points.
    w: [B,N] float tensor -- volume rendering weight of each point.
    m: [B,N] float tensor -- midpoint distance of each point to the camera.
    interval: scalar or [B,N] float tensor -- query interval of each point.
    '''
    # Self term: each sample weighted against its own interval.
    uni_term = (interval * w.square()).sum(-1).mean() / 3
    # Pairwise term: w_i * w_j weighted by |m_i - m_j| over all point pairs.
    pair_w = w[..., :, None] * w[..., None, :]            # [B,N,N]
    pair_d = (m[..., :, None] - m[..., None, :]).abs()    # [B,N,N]
    bi_term = (pair_w * pair_d).sum(dim=(-1, -2)).mean()
    return uni_term + bi_term
def gen_example(B, N, device='cuda'):
    '''
    Build a random but valid benchmark input of B rays with N points each.

    B: number of rays.
    N: number of sampled points per ray.
    device: torch device for the generated tensors. Default 'cuda' keeps the
            original behavior; pass 'cpu' to run without a GPU.

    Returns (w, m, interval):
        w: [B,N] random weights normalized to sum to 1 per ray; requires grad
           so the backward pass can be benchmarked.
        m: [B,N] midpoints of a uniform partition of [0,1] into N bins.
        interval: python float, the uniform bin width 1/N.
    '''
    w = torch.rand(B, N, device=device)
    w = w / w.sum(-1, keepdim=True)
    # clone() detaches from the normalization graph before enabling grad
    w = w.clone().requires_grad_()
    s = torch.linspace(0, 1, N+1, device=device)
    m = (s[1:] + s[:-1]) * 0.5  # bin midpoints
    m = m[None].repeat(B, 1)
    interval = 1/N
    return w, m, interval
def spec(f, NTIMES, *args):
    """Benchmark the loss function ``f(*args)`` on CUDA.

    Phase 1 times NTIMES forward+backward passes (plus one discarded
    warm-up) and prints the total wall-clock seconds. Phase 2 runs one
    more forward and one more backward with the peak-memory counter reset
    before each, and prints the larger of the two peaks in MB.

    f: callable returning a scalar loss tensor that supports .backward().
    NTIMES: number of timed iterations (one extra warm-up run is discarded).
    *args: forwarded verbatim to f.

    Requires a CUDA device. NOTE(review): gradients accumulate into the
    leaf tensors across iterations (no zeroing) — harmless for pure
    time/memory measurement, presumably intentional.
    """
    ts_forward = []
    ts_backward = []
    # i == 0 is a warm-up pass (kernel compilation, allocator growth);
    # its timings are discarded below.
    for i in range(1+NTIMES):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()  # drain pending GPU work before starting the clock
        s_time = time.time()
        loss = f(*args)
        torch.cuda.synchronize()  # forward kernels are async; wait before reading the clock
        e_time = time.time()
        if i>0:  # skip the warm-up measurement
            ts_forward.append(e_time - s_time)
        torch.cuda.synchronize()
        s_time = time.time()
        loss.backward()
        torch.cuda.synchronize()  # same async fence for the backward kernels
        e_time = time.time()
        if i>0:
            ts_backward.append(e_time - s_time)
        del loss  # free the graph before the next iteration's empty_cache()
    #print(f'forward : {np.sum(ts_forward):.1f} sec.')
    #print(f'backward: {np.sum(ts_backward):.1f} sec.')
    print(f'total   : {np.sum(ts_forward) + np.sum(ts_backward):.1f} sec.')
    # ---- peak-memory measurement: one forward, then one backward, each
    # with the CUDA peak-allocation counter reset immediately beforehand ----
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    loss = f(*args)
    torch.cuda.synchronize()
    mem_forward = torch.cuda.max_memory_allocated()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    loss.backward()
    torch.cuda.synchronize()
    mem_backward = torch.cuda.max_memory_allocated()
    del loss
    #print('forward :', mem_forward/1024/1024, 'MB.')
    #print('backward:', mem_backward/1024/1024, 'MB.')
    print(f'total   : {max(mem_forward,mem_backward)/1024/1024:.0f} MB.')
if __name__ == '__main__':
    # B rays N points
    B = 8192
    NTIMES = 100
    for N in [32, 64, 128, 256, 384, 512]:
        print(f' B={B}; N={N} '.center(50, '='))
        w, m, interval = gen_example(B, N)
        ray_id = torch.arange(len(w))[:,None].repeat(1,N).cuda()
        # Every implementation under test, with the arguments it expects.
        cases = [
            (' original_distloss ', original_distloss,
             (w, m, interval)),
            (' eff_distloss_native ', eff_distloss_native,
             (w, m, interval)),
            (' eff_distloss ', eff_distloss,
             (w, m, interval)),
            (' flatten_eff_distloss ', flatten_eff_distloss,
             (w.flatten(), m.flatten(), interval, ray_id.flatten())),
        ]
        for title, fn, fn_args in cases:
            # An OOM (RuntimeError) in one implementation should not stop
            # the remaining ones from being benchmarked.
            try:
                print(title.center(50, '.'))
                spec(fn, NTIMES, *fn_args)
            except RuntimeError as e:
                print(e)