1
1
import os
2
- from typing import List
2
+ import sys
3
+ from typing import Callable , List , Optional
3
4
4
5
import numpy as np
5
6
10
11
logger = create_logger (__name__ )
11
12
12
13
14
+ class EpochResult (object ):
15
+ def __init__ (self , file_name : str ):
16
+ self .file_name = file_name
17
+
18
+ def load (self ) -> List [List [ResultRecord ]]:
19
+ return load_records_from_file (self .file_name )
20
+
21
+ @staticmethod
22
+ def losses (records : List [List [ResultRecord ]]) -> List [List [float ]]:
23
+ return [[r .loss for r in rec ] for rec in records ]
24
+
25
+ @staticmethod
26
+ def metric (records : List [List [ResultRecord ]]) -> List [List [float ]]:
27
+ return [[r .metric for r in rec ] for rec in records ]
28
+
29
+ @staticmethod
30
+ def task_name (records : List [List [ResultRecord ]]) -> List [str ]:
31
+ return [r .name for r in records [0 ]]
32
+
33
+ @staticmethod
34
+ def metric_name (records : List [List [ResultRecord ]]) -> List [str ]:
35
+ return [r .metric_name for r in records [0 ]]
36
+
37
+ @staticmethod
38
+ def reduce (
39
+ values : List [List [float ]],
40
+ reduce_task : Optional [Callable ] = np .mean ,
41
+ reduce_iter : Optional [Callable ] = np .mean ,
42
+ ) -> np .ndarray :
43
+ if reduce_task is None and reduce_iter is None :
44
+ raise ValueError ("Must reduce on something" )
45
+
46
+ if reduce_task is not None :
47
+ values = [reduce_task (np .asarray (vv ), axis = - 1 ) for vv in values ]
48
+
49
+ if reduce_iter is not None :
50
+ values = reduce_iter (np .asarray (values ), axis = 0 )
51
+
52
+ return values
53
+
54
+
13
55
class ExperimentResult (object ):
14
56
def __init__ (self , config : ExperimentConfig , output_dir : str ):
15
57
self .config = config
16
58
self .output_dir = output_dir
17
- self ._records_dir = os .path .join (self .output_dir , "train" )
59
+ self ._records_dir_train = os .path .join (self .output_dir , "train" )
60
+ self ._records_dir_eval = os .path .join (self .output_dir , "evaluation" )
18
61
19
62
def best_epoch (self ) -> int :
20
- losses = []
21
- for epoch in range (1 , self .config .trainer .epochs + 1 ):
22
- try :
23
- file_name = os .path .join (self ._records_dir , f"eval-{ epoch } " )
24
- records_valid = load_records_from_file (file_name )
25
- loss = np .asarray (
26
- [self ._records_loss (rs ) for rs in records_valid ]
27
- ).mean ()
28
- losses .append (loss )
29
- except FileNotFoundError :
30
- logger .warning (
31
- f"Training did not complete { epoch - 1 } /{ self .config .trainer .epochs } "
32
- )
33
- break
63
+ losses = self .metric ("train" , EpochResult .losses )
34
64
35
65
indexes = np .argsort (np .asarray (losses ))
36
66
index = indexes [0 ]
@@ -40,5 +70,38 @@ def best_epoch(self) -> int:
40
70
logger .info (f"Found the best epoch to be { epoch } with valid loss { valid_loss } " )
41
71
return epoch
42
72
43
- def _records_loss (self , records : List [ResultRecord ]) -> float :
44
- return np .asarray ([r .loss for r in records ]).mean ()
73
+ def records (self , tag : str , train : bool = True ) -> List [EpochResult ]:
74
+ records_dir = self ._records_dir_train if train else self ._records_dir_eval
75
+
76
+ results : List [EpochResult ] = []
77
+ for epoch in range (1 , sys .maxsize ):
78
+ file_name = os .path .join (records_dir , f"{ tag } -{ epoch } " )
79
+ if not os .path .exists (file_name ):
80
+ break
81
+
82
+ results .append (EpochResult (file_name ))
83
+
84
+ return results
85
+
86
+ def task_names (self , tag : str , train : bool = True ) -> List [str ]:
87
+ e_records = self .records (tag , train = train )[0 ]
88
+ return EpochResult .task_name (e_records .load ())
89
+
90
+ def metric_names (self , tag : str , train : bool = True ) -> List [str ]:
91
+ e_records = self .records (tag , train )[0 ]
92
+ return EpochResult .metric_name (e_records .load ())
93
+
94
+ def metric (
95
+ self , tag : str , metric , reduce_task = np .mean , reduce_iter = np .mean , train = True
96
+ ) -> np .ndarray :
97
+ e_records = self .records (tag , train = train )
98
+ return np .asarray (
99
+ [
100
+ EpochResult .reduce (
101
+ metric (records .load ()),
102
+ reduce_task = reduce_task ,
103
+ reduce_iter = reduce_iter ,
104
+ )
105
+ for records in e_records
106
+ ]
107
+ )
0 commit comments