11 | 11 |
12 | 12 | @dataclass
13 | 13 | class TrainConfig:
14 | | - """Class which stores all training configuration"""
| 14 | + """Class which stores all training configurations"""
15 | 15 |
16 | | - # div factor is a constant which can be used to reduce the batch size and learning rate respectively
17 | | - # use a value higher 1 if you encounter memory allocation errors
18 | | - div_factor: int = 2
| 16 | + info_div_factor: str = "div factor is a constant which can be used to reduce the batch size and learning rate" \
| 17 | + " respectively. Use a value higher than 1 if you encounter memory allocation errors."
| 18 | + div_factor: int = 1
19 | 19 |
20 | | - # 1024 # the batch_size needed to be reduced to 1024 in order to fit in the GPU 1080Ti
21 | | - # 4096 was originally used in the paper -> works slower for current GPU
22 | | - # 2048 was used in the paper Mastering the game of Go without human knowledge and fits in GPU memory
23 | | - # typically if you half the batch_size you should double the lr
| 20 | + info_batch_size: str = "batch size used during training. The batch size may need to be reduced in order to fit" \
| 21 | + " into your GPU memory. 4096 was originally used in the paper, 2048 was used in the paper" \
| 22 | + " 'Mastering the game of Go without human knowledge'. Typically if you halve the batch_size" \
| 23 | + " you should double the learning rate."
24 | 24 | batch_size: int = int(1024 / div_factor)
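As a quick arithmetic illustration of how div_factor scales these defaults (a sketch derived only from the values above and from the rl_train_config() overrides further down, not from additional project code):

    div_factor = 2                       # the value set later in rl_train_config()
    batch_size = int(1024 / div_factor)  # -> 512
    max_lr = 0.1 / div_factor            # -> 0.05
    batch_steps = 100 * div_factor       # -> 200, i.e. evaluate and checkpoint every 200 batches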
25 | 25 |
26 | | - # batch_steps = 1000 means for example that every 1000 batches the validation set gets processed
27 | | - # this defines how often a new checkpoint will be saved and the metrics evaluated
28 | | - batch_steps: int = 100 * div_factor
| 26 | + info_batch_steps: str = "batch_steps = 1000 means for example that every 1000 batches the validation set is" \
| 27 | + " processed. It defines how often a new checkpoint will be saved and the metrics evaluated"
| 28 | + batch_steps: int = 1000 * div_factor
29 | 29 |
30 | | - # set the context on CPU switch to GPU if there is one available (strongly recommended for training)
| 30 | + info_context: str = "context defines the computation device to use for training. Set the context to 'gpu' if" \
| 31 | + " there is one available, otherwise you may train on 'cpu' instead."
31 | 32 | context: str = "gpu"
32 | 33 |
| 34 | + info_cpu_count: str = "cpu_count defines the number of cpu cores to use for data processing while training."
33 | 35 | cpu_count: int = 4
34 | 36 |
| 37 | + info_device_id: str = "device_id sets the GPU device to use for training."
35 | 38 | device_id: int = 0
36 | 39 |
| 40 | + info_discount: str = "discount describes the discounting value to use for discounting the value target " \
| 41 | + "until reaching the final terminal value."
37 | 42 | discount: float = 1.0
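The discount entry is only described in prose; one plausible reading, shown purely as an illustrative sketch (the ply count n and the terminal value z are hypothetical names, not part of this change), is an exponential decay of the terminal value over the remaining plies:

    # hypothetical illustration of discounting the value target over n plies
    value_target = (discount ** n) * z   # with the default discount = 1.0 nothing is discounted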
38 | 43 |
| 44 | + info_dropout_rate: str = "dropout_rate describes the dropout percentage as used in the neural network architecture."
39 | 45 | dropout_rate: float = 0
40 | 46 |
41 | | - # directory to write and read weight, log, onnx and other export files
| 47 | + info_export_dir: str = "export_dir sets the directory to write and read weights, logs, onnx and other export" \
| 48 | + " files"
42 | 49 | export_dir: str = "./"
43 | 50 |
| 51 | + info_export_weights: str = "export_weights is a boolean to decide if the neural network weights should be exported" \
| 52 | + " during training."
44 | 53 | export_weights: bool = True
45 | 54 |
| 55 | + info_export_grad_histograms: str = "export_grad_histograms enables or disables the export of gradient diagrams" \
| 56 | + " during training."
46 | 57 | export_grad_histograms: bool = True
47 | 58 |
48 | | - # Decide between 'pytorch', 'mxnet' and 'gluon' style for training
49 | | - # Reinforcement Learning only works with gluon and pytorch atm
| 59 | + info_framework: str = "framework sets the deep learning framework to use. Currently only 'pytorch' is available." \
| 60 | + " mxnet and gluon have been deprecated."
50 | 61 | framework: str = 'pytorch'
51 | 62 |
52 | | - # Boolean if the policy data is also defined in select_policy_from_plane representation
| 63 | + info_is_policy_from_plane_data: str = "is_policy_from_plane_data is a boolean to decide if the policy data is" \
| 64 | + " already defined in select_policy_from_plane / plane representation."
53 | 65 | is_policy_from_plane_data: bool = False
54 | 66 |
| 67 | + info_log_metrics_to_tensorboard: str = "log_metrics_to_tensorboard decides if the metrics should be exported with" \
| 68 | + " tensorboard."
55 | 69 | log_metrics_to_tensorboard: bool = True
56 | 70 |
57 | | - # k_steps_initial defines how many steps have been trained before
58 | | - # (k_steps_initial != 0 if you continue training from a checkpoint)
| 71 | + info_model_type: str = "model_type defines the model type that is used during training (e.g. resnet, vit, risev2," \
| 72 | + " risev3, alphavile, alphavile-tiny, alphavile-small, alphavile-normal, alphavile-large," \
| 73 | + " NextViT)"
| 74 | + model_type: str = "resnet"
| 75 | +
| 76 | + info_k_steps_initial: str = "k_steps_initial defines how many steps have been trained before (k_steps_initial != 0 if" \
| 77 | + " you continue training from a checkpoint)" \
| 78 | + " (TODO: Continuing training from a previous checkpoint is currently not available in" \
| 79 | + " the pytorch training loop.)"
59 | 80 | k_steps_initial: int = 0
60 | | - # these are the weights to continue training with
61 | | - # symbol_file = 'model_init-symbol.json' # model-1.19246-0.603-symbol.json'
62 | | - # tar_file = 'model_init-0000.params' # model-1.19246-0.603-0223.params'
63 | | - symbol_file: str = ''
| 81 | +
| 82 | + info_tar_file: str = "tar_file is the neural network weight file to continue training with"
64 | 83 | tar_file: str = ''
65 | 84 |
66 | | - # # optimization parameters
| 85 | + info_optimizer_name: str = "optimizer_name is the optimizer that is used in the training loop to update the" \
| 86 | + " weights (e.g. nag, sgd, adam, adamw)."
67 | 87 | optimizer_name: str = "nag"
68 | | - max_lr: float = 0.1 / div_factor # 0.35 / div_factor
69 | | - min_lr: float = 0.00001 / div_factor # 0.2 / div_factor # 0.00001
| 88 | +
| 89 | + info_max_lr: str = "max_lr defines the maximum learning rate used for training."
| 90 | + max_lr: float = 0.07 / div_factor
| 91 | + info_min_lr: str = "min_lr defines the minimum learning rate used for training."
| 92 | + min_lr: float = 0.00001 / div_factor
| 93 | +
| 94 | + if "adam" in optimizer_name:
| 95 | +     max_lr = 0.001001  # 1e-3
| 96 | +     min_lr = 0.001
| 97 | +
| 98 | + info_max_momentum: str = "max_momentum defines the maximum momentum factor used during training (only applicable" \
| 99 | + " to optimizers that are momentum based)"
70 | 100 | max_momentum: float = 0.95
| 101 | + info_min_momentum: str = "min_momentum defines the minimum momentum factor used during training (only applicable" \
| 102 | + " to optimizers that are momentum based)"
71 | 103 | min_momentum: float = 0.8
72 | | - # stop training as soon as max_spikes has been reached
| 104 | +
| 105 | + info_max_spikes: str = "max_spikes defines the maximum number of spikes. Training is stopped as soon as max_spikes" \
| 106 | + " has been reached."
73 | 107 | max_spikes: int = 20
74 | 108 |
75 | 109 | # name initials which are used to identify running training processes with rtpt
76 | 110 | # prefix for the process name in order to identify the process on a server
| 111 | + info_name_initials: str = "name_initials sets the name initials which are used to identify running training" \ |
| 112 | + " processes with rtpt. It is used as a prefix for the process name in order to identify" \ |
| 113 | + " the process on a server." |
77 | 114 | name_initials: str = "XX"
78 | 115 |
| 116 | + info_nb_parts: str = "nb_parts sets the number of training zip files used for training. This value is normally " \ |
| 117 | + "dynamically set before training based on the number of .zip files available in the training " \ |
| 118 | + "directory." |
79 | 119 | nb_parts: int = None
80 | 120 |
| 121 | + info_normalize: str = "normalize decides if the training data should be normalized to the range of [0,1]." |
81 | 122 | normalize: bool = True # define whether to normalize input data to [01]
82 | 123 |
83 | | - # how many epochs the network will be trained each time there is enough new data available
84 | | - nb_training_epochs: int = 1
| 124 | + info_nb_training_epochs: str = "nb_training_epochs defines for how many epochs the network will be trained."
| 125 | + nb_training_epochs: int = 7
85 | 126 |
86 | | - policy_loss_factor: float = 0.5 # 0.99
| 127 | + info_plys_to_end_loss_factor: str = "plys_to_end_loss_factor defines the gradient scaling for the plys to end" \
| 128 | + " output."
| 129 | + plys_to_end_loss_factor: float = 0.002
87 | 130 |
88 | | - # gradient scaling for the plys to end output
89 | | - plys_to_end_loss_factor: float = 0.1
| 131 | + info_q_value_ratio: str = "q_value_ratio defines the ratio for mixing the value return with the corresponding" \
| 132 | + " q-value. For a ratio of 0 no q-value information will be used."
| 133 | + q_value_ratio: float = 0.0
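The mixing described by info_q_value_ratio can be pictured as a simple linear blend; this is an assumed formulation for illustration only (game_return and q_value are hypothetical names, not identifiers from this change):

    # hypothetical blend of the game outcome with the search q-value
    value_target = (1 - q_value_ratio) * game_return + q_value_ratio * q_value
    # q_value_ratio = 0.0 keeps the plain return; rl_train_config() below sets it to 0.15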
90 | 134 |
91 | | - # ratio for mixing the value return with the corresponding q-value
92 | | - # for a ratio of 0 no q-value information will be used
93 | | - q_value_ratio: float = 0.15
94 | | -
95 | | - # set a specific seed value for reproducibility
| 135 | + info_seed: str = "seed sets a specific seed value for reproducibility."
96 | 136 | seed: int = 42
97 | 137 |
98 | | - # Boolean if potential legal moves will be selected from final policy output
| 138 | + info_select_policy_from_plane: str = "select_policy_from_plane defines if potential legal moves will be selected" \
| 139 | + " from final policy output in plane representation / convolution" \
| 140 | + " representation rather than a flat representation."
99 | 141 | select_policy_from_plane: bool = True
100 | 142 |
101 | | - # define spike threshold when the detection will be triggered
| 143 | + info_spike_thresh: str = "spike_thresh defines the spike threshold when the detection will be triggered. It is" \
| 144 | + " triggered when last_loss x spike_thresh < current_loss."
102 | 145 | spike_thresh: float = 1.5
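Taken together, the spike-related options (spike_thresh, max_spikes, use_spike_recovery) describe a loss-spike check roughly like the following sketch; the control flow and helper names are assumptions for illustration, not code from this change:

    # assumed spike handling per evaluation step
    if current_loss > last_loss * spike_thresh:   # condition from info_spike_thresh
        spike_count += 1
        if use_spike_recovery:
            load_previous_checkpoint()            # hypothetical helper
        if spike_count >= max_spikes:
            stop_training()                       # hypothetical helper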
103 | 146 |
104 | | - # Boolean if the policy target is one-hot encoded (sparse=True) or a target distribution (sparse=False)
105 | | - sparse_policy_label: bool = False
| 147 | + info_sparse_policy_label: str = "sparse_policy_label defines if the policy target is one-hot encoded (sparse=True)" \
| 148 | + " or a target distribution (sparse=False)"
| 149 | + sparse_policy_label: bool = True
106 | 150 |
107 | | - # total of training iterations
| 151 | + info_total_it: str = "total_it defines the total number of training iterations. Usually this value is determined" \
| 152 | + " dynamically based on the number of zip files and the number of samples in the validation file."
108 | 153 | total_it: int = None
109 | 154 |
110 | | - # adds a small mlp to infer the value loss from wdl and plys_to_end_output
| 155 | + info_use_custom_architecture: str = "use_custom_architecture decides if a custom network architecture should be " \
| 156 | + "used, defined in the model_config.py file"
| 157 | + use_custom_architecture: bool = False
| 158 | +
| 159 | + info_use_mlp_wdl_ply: str = "use_mlp_wdl_ply adds a small mlp to infer the value loss from wdl and plys_to_end" \
| 160 | + "_output"
111 | 161 | use_mlp_wdl_ply: bool = False
112 | | - # enables training with ply to end head
113 | | - use_plys_to_end: bool = False
114 | | - # enables training with a wdl head as intermediate target (mainly useful for environments with 3 outcomes)
115 | | - use_wdl: bool = False
| 162 | + info_use_plys_to_end: str = "use_plys_to_end enables training with the plys to end head."
| 163 | + use_plys_to_end: bool = True
| 164 | + info_use_wdl: str = "use_wdl enables training with a wdl head as intermediate target (mainly useful for" \
| 165 | + " environments with three outcomes WIN, DRAW, LOSS)"
| 166 | + use_wdl: bool = True
116 | 167 |
117 | | - # loads a previous checkpoint if the loss increased significantly
| 168 | + info_use_spike_recovery: str = "use_spike_recovery loads a previous checkpoint if the loss increased significantly."
118 | 169 | use_spike_recovery: bool = True
119 | | - # weight the value loss a lot lower than the policy loss in order to prevent overfitting
120 | | - val_loss_factor: float = 0.5 # 0.01
121 | | - # weight for the wdl loss
122 | | - wdl_loss_factor: float = 0.4
| 170 | + info_val_loss_factor: str = "val_loss_factor weights the value loss a lot lower than the policy loss in order to" \
| 171 | + " prevent overfitting"
| 172 | + val_loss_factor: float = 0.01
| 173 | + info_policy_loss_factor: str = "policy_loss_factor defines the weighting factor for the policy loss."
| 174 | + policy_loss_factor: float = 0.988 if use_plys_to_end else 0.99
123 | 175 |
124 | | - # weight decay
| 176 | + info_wdl_loss_factor: str = "wdl_loss_factor defines the weighting factor for the wdl-loss."
| 177 | + wdl_loss_factor: float = 0.01
| 178 | +
| 179 | + info_wd: str = "wd defines the weight decay value for regularization as a measure to prevent overfitting."
125 | 180 | wd: float = 1e-4
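The loss factors above suggest a weighted sum over the network heads; the exact combination is not part of this change, so the following is only a hedged sketch with assumed loss variable names:

    # assumed combined training objective using the weighting factors above
    loss = policy_loss_factor * policy_loss + val_loss_factor * value_loss
    if use_wdl:
        loss += wdl_loss_factor * wdl_loss
    if use_plys_to_end:
        loss += plys_to_end_loss_factor * plys_to_end_loss
    # wd (weight decay) is typically handed to the optimizer rather than added to the loss by hand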
126 | 181 |
127 | 182 |
| 183 | + def rl_train_config():
| 184 | +     tc = TrainConfig()
| 185 | +
| 186 | +     tc.export_grad_histograms = True
| 187 | +     tc.div_factor = 2
| 188 | +     tc.batch_steps = 100 * tc.div_factor
| 189 | +     tc.batch_size = int(1024 / tc.div_factor)
| 190 | +
| 191 | +     tc.max_lr = 0.1 / tc.div_factor
| 192 | +     tc.min_lr = 0.00001 / tc.div_factor
| 193 | +
| 194 | +     tc.val_loss_factor = 0.499 if tc.use_plys_to_end else 0.5
| 195 | +     tc.policy_loss_factor = 0.499 if tc.use_plys_to_end else 0.5
| 196 | +     tc.plys_to_end_loss_factor = 0.002
| 197 | +     tc.wdl_loss_factor = 0.499 if tc.use_plys_to_end else 0.5
| 198 | +
| 199 | +     tc.nb_training_epochs = 1  # define how many epochs the network will be trained
| 200 | +     tc.q_value_ratio = 0.15
| 201 | +     tc.sparse_policy_label = False
| 202 | +
| 203 | +     return tc
| 204 | +
| 205 | +
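A short usage sketch of the two configurations introduced above (the numbers follow directly from the defaults and the overrides in rl_train_config()):

    tc = TrainConfig()         # supervised defaults: div_factor=1, batch_size=1024, nb_training_epochs=7
    rl_tc = rl_train_config()  # RL setup: div_factor=2 -> batch_size=512, max_lr=0.05, nb_training_epochs=1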
128 | 206 | @dataclass
129 | 207 | class TrainObjects:
130 | 208 | """Defines training objects which must be set before the training"""