-
Notifications
You must be signed in to change notification settings - Fork 1
/
find_unused_params.py
49 lines (33 loc) · 1.14 KB
/
find_unused_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import yaml
import torch
import argparse
torch.set_float32_matmul_precision("high")
from src.models import TDAVNet
from src.utils import parse_args_as_dict
from src.system import make_optimizer
from src.losses import PITLossWrapper, pairwise_neg_snr
# Dummy inputs for a single forward/backward probe, moved to CUDA device 0
# (an integer device index).
# NOTE(review): axis meanings below are inferred from usage in main(), not
# established here — confirm against TDAVNet's forward signature.
#   x: (2, 32000)      presumably batch of mixture waveforms
#   z: (2, 1, 32000)   loss target passed to PITLossWrapper
#   y: (2, 512, 50)    second model input (presumably visual/embedding features)
# The tuple is evaluated left to right, so random draws happen in the order x, z, y.
x, z, y = (
    torch.rand(2, 32000).to(0),
    torch.rand(2, 1, 32000).to(0),
    torch.rand(2, 512, 50).to(0),
)
def main(conf):
    """Run one forward/backward pass and report trainable parameters without a gradient.

    Builds ``TDAVNet`` from ``conf["audionet"]`` and an optimizer from
    ``conf["optim"]``. It then feeds the module-level dummy inputs ``x`` and
    ``y`` through the model, computes a ``PITLossWrapper(pairwise_neg_snr)``
    loss against ``z`` and calls ``backward()``. Finally it prints the name of
    every parameter that requires grad but whose ``.grad`` is still ``None``,
    i.e. parameters the loss did not reach.

    Args:
        conf: Configuration dict. Must contain ``"audionet"`` (keyword
            arguments for ``TDAVNet``) and ``"optim"`` (keyword arguments
            for ``make_optimizer``).

    Returns:
        List of names of trainable parameters that received no gradient,
        in ``named_parameters()`` order.
    """
    audiomodel = TDAVNet(**conf["audionet"]).to(0)
    optimizer = make_optimizer(audiomodel.parameters(), **conf["optim"])
    # Define loss function.
    loss_func = PITLossWrapper(pairwise_neg_snr, pit_from="pw_mtx").to(0)

    optimizer.zero_grad()
    z1 = audiomodel(x, y)
    loss = loss_func(z1, z)
    loss.backward()

    unused = []
    for name, param in audiomodel.named_parameters():
        # Frozen parameters (requires_grad=False) never receive a .grad, so
        # without this check they would be falsely reported as unused.
        if param.requires_grad and param.grad is None:
            print(name)
            unused.append(name)
    return unused
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--conf-dir",
        default="config/lrs2_RTFSNet_4_layer.yaml",
        help="Path to the YAML configuration file.",
    )
    args = parser.parse_args()

    # Explicit encoding: YAML configs are UTF-8. Without it, decoding would
    # depend on the platform's locale default.
    with open(args.conf_dir, encoding="utf-8") as f:
        def_conf = yaml.safe_load(f)
    # Overlay whatever dict the project helper builds from the parser.
    # NOTE(review): its exact layout is defined in src.utils, not visible here.
    arg_dic = parse_args_as_dict(parser)
    def_conf.update(arg_dic)
    main(def_conf)