-
Notifications
You must be signed in to change notification settings - Fork 2
/
contribution_analysis.py
41 lines (29 loc) · 1.23 KB
/
contribution_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/env python3
import json
import hydra
from cyy_naive_lib.log import get_logger
from cyy_torch_algorithm.hydra.hydra_config import HyDRAConfig
from util import analysis_contribution, save_image
# Module-level state shared with ``load_config``: the ``__main__`` block calls
# ``load_config()`` without using its return value and then reads these globals.
# ``config`` is passed to ``HyDRAConfig.load_config`` and later used to build the
# trainer (presumably it is filled in place by that call — verify in HyDRAConfig).
config: HyDRAConfig = HyDRAConfig()
other_config = None  # set to HyDRAConfig.load_config's return value by load_config
@hydra.main(config_path="conf", version_base=None)
def load_config(conf):
    """Hydra task function that loads the composed config into module globals.

    The module-level ``config`` and the (possibly unwrapped) Hydra config are
    passed to ``HyDRAConfig.load_config`` with ``check_config=False``.  Its
    return value is published as the module-level ``other_config``.

    Args:
        conf: configuration composed by Hydra from the ``conf`` directory.  When
            it has exactly one top-level entry, that entry's value is used in its
            place, i.e. a config nested under a single key is unwrapped.
    """
    global other_config
    # ``config`` is only read here, so it needs no ``global`` declaration;
    # ``other_config`` is rebound and therefore does.
    if len(conf) == 1:
        (conf,) = conf.values()
    other_config = HyDRAConfig.load_config(config, conf, check_config=False)
if __name__ == "__main__":
    load_config()
    # Hydra only calls the task function in run/multirun mode.  With flags such
    # as ``--cfg`` or ``--info`` it returns without running ``load_config``'s
    # body, so ``other_config`` is still None here.  Stop cleanly instead of
    # failing below with "'NoneType' object is not subscriptable".
    if other_config is None:
        raise SystemExit(0)
    trainer = config.create_trainer()
    # JSON object keys are always strings; convert them back to ints (they are
    # later passed to save_image as ``index`` — presumably sample indices).
    with open(other_config["contribution_path"], mode="rt", encoding="utf8") as f:
        contribution_dict = {int(k): v for k, v in json.load(f).items()}
    positive_contributions, negative_contributions = analysis_contribution(
        contribution_dict, threshold=other_config["threshold"]
    )
    logger = get_logger()
    logger.info("positive contributions are %s", positive_contributions)
    logger.info("negative contributions are %s", negative_contributions)
    # Save one image per selected index, positive contributions first.
    # "." is presumably the output directory for save_image — confirm in util.
    for contributions in (positive_contributions, negative_contributions):
        for k in contributions:
            save_image(".", trainer, contributions, index=k)