
Commit 19e37d0

Selfplay MoE (#223)
* Change exporter into vector<unique_ptr> exporters (one exporter object for each phase)
* Add phase id specifier for export and make sure that only the appropriate phase exports the sample
* Use max_samples_per_iteration() to end generation
* Add check_for_moe(model_dir)
* Add directories for MoE
* Launch the training procedure multiple times for MoE
* Add logging message
* Start adding special cases for MoE
* Add _move_all_files_wrapper() for cleaner code
* Minor comment update
* Fix compile errors
* Add select_nn_index() to RawNetAgent
* Fix condition
* Add TODO, remove unnecessary '/'
* Update compress_dataset
* Update fileNameExport in selfplay.cpp
* Update get_current_model_tar_file to use model 0 for MoE
* Reset generatedSamples in go()
* Add missing /
* Skip game export for 0 samples
* Use phase0 in get_number_generated_files()
* Fix export of phases, update get_current_model_tar_file()
* Update get_current_model_tar_file()
* Update prepare_data_for_training()
* Update _move_all_files_wrapper()
* Remove unneeded /
* Update get_current_model_tar_file() with phases
* Update planes_train_dir and planes_val_dir in rl_loop.py
* Simplify code and use reversed() for training
* Add missing / in compress_dataset
* Add _retrieve_end_idx(data) for MoE
* Implement _include_data_from_replay_memory_wrapper() which handles MoE and non-MoE cases
* Implement staged learning v2.0, i.e. first train on full data and then each phase separately
* Make use_moe_staged_learning auto-detect
* Skip "phaseNone" for counting phases
* Fix condition for "phase_idx is None"
* Skip "phaseNone" when loading models
* Fix suffix
* Create model_dir_archive for phaseNone
* Add load checkpoint logging info
* Set q_value_ratio to 0 for RL and add Exception for wdl is True conflict
* Use middle phase for validation in staged learning on full data
1 parent 025793a · commit 19e37d0

11 files changed: +309 additions, -64 deletions


DeepCrazyhouse/configs/train_config.py

Lines changed: 2 additions & 2 deletions
@@ -71,7 +71,7 @@ class TrainConfig:
     info_model_type: str = "model_type defines the Model type that used during training (e.g. resnet, vit, risev2," \
                            " risev3, alphavile, alphavile-tiny, alphavile-small, alphavile-normal, alphavile-large," \
                            " NextViT)"
-    model_type: str = "resnet"
+    model_type: str = "risev3"

     info_k_steps_initial: str = "k_steps_initial defines how many steps have been trained before (k_steps_initial != 0 if" \
                                 " you continue training from a checkpoint)" \
@@ -197,7 +197,7 @@ def rl_train_config():
     tc.wdl_loss_factor = 0.499 if tc.use_plys_to_end else 0.5

     tc.nb_training_epochs = 1 # define how many epochs the network will be trained
-    tc.q_value_ratio = 0.15
+    tc.q_value_ratio = 0 # previously 0.15
     tc.sparse_policy_label = False

     return tc
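
For context, the sketch below illustrates what this ratio controls, assuming the usual reading of q_value_ratio as the mixing weight between the self-play game result and the MCTS Q-value when forming value targets; the function name, array names and formula are illustrative, not repository code. Setting it to 0 for RL keeps pure game outcomes, which the new use_wdl consistency check in trainer_agent_pytorch.py (next file) also requires.

# Illustrative sketch only (assumed names and blending formula):
import numpy as np

def blend_value_targets(z: np.ndarray, q: np.ndarray, q_value_ratio: float) -> np.ndarray:
    # mix the final game result z with the search Q-value q
    return (1 - q_value_ratio) * z + q_value_ratio * q

z = np.array([1.0, -1.0, 0.0])   # game results from self-play
q = np.array([0.6, -0.2, 0.1])   # corresponding MCTS Q-values
assert np.allclose(blend_value_targets(z, q, 0.0), z)                       # new RL setting
assert np.allclose(blend_value_targets(z, q, 0.15), 0.85 * z + 0.15 * q)    # previous setting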

DeepCrazyhouse/src/training/trainer_agent_pytorch.py

Lines changed: 3 additions & 0 deletions
@@ -58,6 +58,9 @@ def __init__(
         """
         self.additional_loaders = additional_loaders
         self.tc = train_config
+        if self.tc.use_wdl and self.tc.q_value_ratio != 0:
+            raise Exception("q_value_ratio must be 0 for use_wdl = True.")
+
         self.to = train_objects
         if self.to.metrics is None:
             self.to.metrics = {}

engine/src/agents/rawnetagent.cpp

Lines changed: 8 additions & 1 deletion
@@ -35,6 +35,13 @@ RawNetAgent::RawNetAgent(const vector<unique_ptr<NeuralNetAPI>>& nets, const Pla
 {
 }

+size_t RawNetAgent::select_nn_index() {
+    if (nets.size() == 1) {
+        return 0;
+    }
+    return phaseToNetsIndex.at(state->get_phase(numPhases, searchSettings->gamePhaseDefinition));
+}
+
 void RawNetAgent::evaluate_board_state()
 {
     evalInfo->legalMoves = state->legal_actions();
@@ -55,7 +62,7 @@ void RawNetAgent::evaluate_board_state()
         return;
     }
     state->get_state_planes(true, inputPlanes, nets.front()->get_version());
-    nets[phaseToNetsIndex.at(state->get_phase(numPhases, searchSettings->gamePhaseDefinition))]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
+    nets[select_nn_index()]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
     state->set_auxiliary_outputs(auxiliaryOutputs);

     evalInfo->policyProbSmall.resize(evalInfo->legalMoves.size());
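
A rough Python paraphrase of the new dispatch, for illustration only; the C++ members phaseToNetsIndex, numPhases and get_phase() are modeled here with stand-in names.

# Hedged sketch: how select_nn_index() routes an evaluation to the phase expert.
def select_nn_index(nets, phase_to_nets_index, current_phase):
    if len(nets) == 1:          # single network: no MoE, always index 0
        return 0
    return phase_to_nets_index[current_phase]

nets = ["net_phase0", "net_phase1", "net_phase2"]
phase_to_nets_index = {0: 0, 1: 1, 2: 2}
assert select_nn_index(nets, phase_to_nets_index, current_phase=2) == 2
assert select_nn_index(nets[:1], phase_to_nets_index, current_phase=2) == 0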

engine/src/agents/rawnetagent.h

Lines changed: 8 additions & 0 deletions
@@ -41,6 +41,14 @@ using namespace crazyara;

 class RawNetAgent: public Agent
 {
+private:
+    /**
+     * @brief select_nn_index Returns the index according to phaseToNetsIndex.
+     * If no phases is enabled, 0 will be returned.
+     * @return phaseToNetsIndex index or 0
+     */
+    size_t select_nn_index();
+
 public:
     const SearchSettings* searchSettings;

engine/src/rl/fileio.py

Lines changed: 181 additions & 34 deletions
@@ -65,6 +65,25 @@ def compress_zarr_dataset(data, file_path, compression='lz4', clevel=5, start_id
     return nan_detected


+def check_for_moe(model_dir: str):
+    """
+    Extracts the number of phases from the given model directory.
+    Returns true if mixture of experts is used.
+    The second return argument is the number of phases.
+    :param model_dir: Model directory, where either the model directly is stored or the number of phase directories.
+    :return: is_moe: bool, number_phases: int or None
+    """
+    number_phases = 0
+    is_moe = False
+    for entry in os.listdir(model_dir):
+        if entry.startswith("phase") and "None" not in entry:
+            number_phases += 1
+            is_moe = True
+    if not is_moe:
+        number_phases = None
+    return is_moe, number_phases
+
+
 class FileIO:
     """
     Class to facilitate creation of directories, reading of file
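
A small usage sketch for check_for_moe(); the import path and the temporary directory layout are assumptions for illustration.

import os
import tempfile

from fileio import check_for_moe  # assumes fileio.py (engine/src/rl/) is on the path

# MoE layout: one sub-directory per phase; "phaseNone" (staged learning) is not counted.
with tempfile.TemporaryDirectory() as model_dir:
    for name in ("phase0", "phase1", "phase2", "phaseNone"):
        os.mkdir(os.path.join(model_dir, name))
    assert check_for_moe(model_dir) == (True, 3)

# Flat layout without phase directories: no MoE, number of phases is None.
with tempfile.TemporaryDirectory() as model_dir:
    assert check_for_moe(model_dir) == (False, None)
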
@@ -100,6 +119,17 @@ def __init__(self, orig_binary_name: str, binary_dir: str, uci_variant: str):

         self.timestamp_format = "%Y-%m-%d-%H-%M-%S"

+        self.is_moe, self.number_phases = check_for_moe(self.model_dir)
+
+        # Whether to use Staged learning v2.0 for MoE training,
+        # i.e. first train on full data and then each phase separately
+        self.use_moe_staged_learning = True if os.path.isdir(self.model_dir + "phaseNone") else False
+
+        if self.is_moe:
+            logging.info(f"Mixture of experts detected with {self.number_phases} phases.")
+            logging.info(f"Use MoE staged learning is {self.use_moe_staged_learning}.")
+        else:
+            logging.info("No mixture of experts detected.")
         self._create_directories()

         # Adjust paths in main_config
@@ -124,14 +154,42 @@ def _create_directories(self):
         create_dir(self.model_dir_archive)
         create_dir(self.logs_dir_archive)

-    def _include_data_from_replay_memory(self, nb_files: int, fraction_for_selection: float):
+        if self.is_moe:
+            for directory in [self.export_dir_gen_data, self.train_dir, self.val_dir, self.train_dir_archive,
+                              self.val_dir_archive, self.model_contender_dir, self.model_dir_archive]:
+                for phase_idx in range(self.number_phases):
+                    create_dir(directory + f"phase{phase_idx}")
+            if self.use_moe_staged_learning:
+                create_dir(self.model_contender_dir + "phaseNone")
+                create_dir(self.model_dir_archive + "phaseNone")
+
+    def _include_data_from_replay_memory_wrapper(self, nb_files: int, fraction_for_selection: float):
         """
+        Wrapper for _include_data_from_replay_memory() which handles MoE and non MoE cases.
+        :param nb_files: Number of files to include from replay memory into training
+        :param fraction_for_selection: Proportion for selecting files from the replay memory
+        """
+
+        if not self.is_moe:
+            self._include_data_from_replay_memory(self.train_dir_archive, self.train_dir, nb_files,
+                                                  fraction_for_selection)
+        else:
+            for phase_idx in range(self.number_phases):
+                self._include_data_from_replay_memory(self.train_dir_archive + f"phase{phase_idx}/",
+                                                      self.train_dir + f"phase{phase_idx}/", nb_files,
+                                                      fraction_for_selection)
+
+    def _include_data_from_replay_memory(self, from_dir: str, to_dir: str, nb_files: int, fraction_for_selection: float):
+        """
+        Moves data from the from_dir directory to the to_dir directory.
+        :param from_dir: Usually train_dir_archive
+        :param to_dir: Usually train_dir
         :param nb_files: Number of files to include from replay memory into training
         :param fraction_for_selection: Proportion for selecting files from the replay memory
         :return:
         """
         # get all file/folder names ignoring hidden files
-        folder_names = [folder_name for folder_name in os.listdir(self.train_dir_archive)
+        folder_names = [folder_name for folder_name in os.listdir(from_dir)
                         if not folder_name.startswith('.')]

         if len(folder_names) < nb_files:
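
Reproduced as a stand-alone sketch, the per-phase layout that the _create_directories() additions build; the base path and phase count are illustrative, the directory names follow the FileIO attributes in the diff above.

import os
import tempfile

number_phases = 3
base = tempfile.mkdtemp()
per_phase = ["export_dir_gen_data", "train_dir", "val_dir", "train_dir_archive",
             "val_dir_archive", "model_contender_dir", "model_dir_archive"]
for directory in per_phase:
    for phase_idx in range(number_phases):
        os.makedirs(os.path.join(base, directory, f"phase{phase_idx}"))
# staged learning v2.0 additionally keeps a "phaseNone" slot for the full-data model
for directory in ("model_contender_dir", "model_dir_archive"):
    os.makedirs(os.path.join(base, directory, "phaseNone"))

print(sorted(os.listdir(os.path.join(base, "model_contender_dir"))))
# ['phase0', 'phase1', 'phase2', 'phaseNone']
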
@@ -153,90 +211,160 @@ def _include_data_from_replay_memory(self, nb_files: int, fraction_for_selection

         # move selected files into train dir
         for index in list(indices):
-            os.rename(self.train_dir_archive + folder_names[index], self.train_dir + folder_names[index])
+            os.rename(from_dir + folder_names[index], to_dir + folder_names[index])

     def _move_generated_data_to_train_val(self):
         """
         Moves the generated samples, games (pgn format) and the number how many games have been generated to the given
         training and validation directory
         :return:
         """
-        file_names = os.listdir(self.export_dir_gen_data)
+        if not self.is_moe:
+            file_names = os.listdir(self.export_dir_gen_data)
+
+            # move the last file into the validation directory
+            os.rename(self.export_dir_gen_data + file_names[-1], self.val_dir + file_names[-1])
+
+            # move the rest into the training directory
+            for file_name in file_names[:-1]:
+                os.rename(self.export_dir_gen_data + file_name, self.train_dir + file_name)
+        else:
+            for phase_idx in range(self.number_phases):
+                file_names = os.listdir(self.export_dir_gen_data + f"/phase{phase_idx}")

-        # move the last file into the validation directory
-        os.rename(self.export_dir_gen_data + file_names[-1], self.val_dir + file_names[-1])
+                # move the last file into the validation directory
+                os.rename(self.export_dir_gen_data + f"/phase{phase_idx}/" + file_names[-1],
+                          self.val_dir + f"/phase{phase_idx}/" + file_names[-1])

-        # move the rest into the training directory
-        for file_name in file_names[:-1]:
-            os.rename(self.export_dir_gen_data + file_name, self.train_dir + file_name)
+                # move the rest into the training directory
+                for file_name in file_names[:-1]:
+                    os.rename(self.export_dir_gen_data + f"/phase{phase_idx}/" + file_name,
+                              self.train_dir + f"/phase{phase_idx}/" + file_name)

     def _move_train_val_data_into_archive(self):
         """
         Moves files from training, validation dir into archive directory
         :return:
         """
-        move_all_files(self.train_dir, self.train_dir_archive)
-        move_all_files(self.val_dir, self.val_dir_archive)
+        self._move_all_files_wrapper(self.train_dir, self.train_dir_archive)
+        self._move_all_files_wrapper(self.val_dir, self.val_dir_archive)

     def _remove_files_in_weight_dir(self):
         """
         Removes all files in the weight directory.
         :return:
         """
-        file_list = glob.glob(os.path.join(self.weight_dir, "model-*"))
-        for file in file_list:
-            os.remove(file)
+        if not self.is_moe:
+            file_list = glob.glob(os.path.join(self.weight_dir, "model-*"))
+            for file in file_list:
+                os.remove(file)
+        else:
+            for phase_idx in range(self.number_phases):
+                file_list = glob.glob(os.path.join(self.weight_dir, f"phase{phase_idx}/model-*"))
+                for file in file_list:
+                    os.remove(file)

-    def compress_dataset(self, device_name: str):
+    def _compress_single_dataset(self, phase: str, device_name: str):
         """
-        Loads the uncompressed data file, selects all sample until the index specified in "startIdx.txt",
-        compresses it and exports it.
-        :param device_name: The currently active device name (context_device-id)
-        :return:
+        Loads a single uncompressed data file, selects all samples, compresses it and exports it.
+        :param phase: Phase to use, e.g. "phase0/", "phase1". Is empty string for no phase ("").
+        :return: export_dir: str
         """
-        data = zarr.load(self.binary_dir + "data_" + device_name + ".zarr")
+        data = zarr.load(self.binary_dir + phase + "data_" + device_name + ".zarr")

-        export_dir, time_stamp = self.create_export_dir(device_name)
+        export_dir, time_stamp = self.create_export_dir(phase, device_name)
         zarr_path = export_dir + time_stamp + ".zip"
-        nan_detected = compress_zarr_dataset(data, zarr_path, start_idx=0)
+
+        end_idx = self._retrieve_end_idx(data)
+
+        nan_detected = compress_zarr_dataset(data, zarr_path, start_idx=0, end_idx=end_idx)
         if nan_detected is True:
             logging.error("NaN value detected in file %s.zip" % time_stamp)
             new_export_dir = self.binary_dir + time_stamp
             os.rename(export_dir, new_export_dir)
             export_dir = new_export_dir
-        self.move_game_data_to_export_dir(export_dir, device_name)

-    def create_export_dir(self, device_name: str) -> (str, str):
+        return export_dir
+
+    def _retrieve_end_idx(self, data):
+        """
+        Checks the y_policy sum in the data for is_moe is False and
+        returns the first occurence of only 0s.
+        An end_idx of 0 means the whole dataset will be used
+        :param data: Zarr data object
+        :return: end_idx
+        """
+        if self.is_moe is False:
+            return 0
+
+        sum_y_policy = data['y_policy'].sum(axis=1)
+        potential_end_idx = sum_y_policy.argmin()
+        if sum_y_policy[potential_end_idx] == 0:
+            return potential_end_idx
+        return 0
+
+    def compress_dataset(self, device_name: str):
+        """
+        Calls _compress_single_dataset() for each phase or a single time for no phases.
+        Also moves the game data to export directory.
+        :param device_name: The currently active device name (context_device-id)
+        :return:
+        """
+        if self.is_moe:
+            for phase_idx in range(self.number_phases):
+                export_dir = self._compress_single_dataset(f"phase{phase_idx}/", device_name)
+                if phase_idx == 0:
+                    self.move_game_data_to_export_dir(export_dir, device_name)
+        else:
+            export_dir = self._compress_single_dataset("", device_name)
+            self.move_game_data_to_export_dir(export_dir, device_name)
+
+    def create_export_dir(self, phase: str, device_name: str) -> (str, str):
         """
         Create a directory in the 'export_dir_gen_data' path,
         where the name consists of the current date, time and device ID.
+        :param phase: Phase to use, e.g. "phase0/", "phase1". Is empty string for no phase ("").
         :param device_name: The currently active device name (context_device-id)
         :return: Path of the created directory; Time stamp used while creating
         """
         # include current timestamp in dataset export file
         time_stamp = datetime.datetime.fromtimestamp(time.time()).strftime(self.timestamp_format)
-        time_stamp_dir = f'{self.export_dir_gen_data}{time_stamp}-{device_name}/'
+        time_stamp_dir = f'{self.export_dir_gen_data}{phase}{time_stamp}-{device_name}/'
         # create a directory of the current time_stamp
         if not os.path.exists(time_stamp_dir):
             os.makedirs(time_stamp_dir)

         return time_stamp_dir, time_stamp

-    def get_current_model_tar_file(self) -> str:
+    def get_current_model_tar_file(self, phase=None, base_dir=None) -> str:
         """
+        :param phase: Phase to use. Should be "" if no MoE is used and otherwise e.g. "phase2".
+        :param base_dir: Should be self.model_dir in the normal case
+        For None default "phase0" or "" will be used.
         Return the filename of the current active model weight (.tar) file for pytorch
         """
-        model_params = glob.glob(self.model_dir + "/*.tar")
+        if phase is None:
+            if self.is_moe:
+                phase = "phase0"
+            else:
+                phase = ""
+        if base_dir is None:
+            base_dir = self.model_dir
+        model_params = glob.glob(base_dir + phase + "/*.tar")
         if len(model_params) == 0:
             raise FileNotFoundError(f'No model file found in {self.model_dir}')
         return model_params[0]

     def get_number_generated_files(self) -> int:
         """
         Returns the amount of file that have been generated since the last training run.
-        :return:
+        :return: nb_training_files: int
         """
-        return len(glob.glob(self.export_dir_gen_data + "**/*.zip"))
+        if self.is_moe:
+            phase = "phase0/"
+        else:
+            phase = ""
+        return len(glob.glob(self.export_dir_gen_data + phase + "**/*.zip"))

     def move_game_data_to_export_dir(self, export_dir: str, device_name: str):
         """
@@ -271,12 +399,12 @@ def prepare_data_for_training(self, rm_nb_files: int, rm_fraction_for_selection:
         if did_contender_win:
             self._move_train_val_data_into_archive()
             # move last contender into archive
-            move_all_files(self.model_contender_dir, self.model_dir_archive)
+            self._move_all_files_wrapper(self.model_contender_dir, self.model_dir_archive)

         self._move_generated_data_to_train_val()
         # We don’t need them anymore; the last model from last training has already been saved
         self._remove_files_in_weight_dir()
-        self._include_data_from_replay_memory(rm_nb_files, rm_fraction_for_selection)
+        self._include_data_from_replay_memory_wrapper(rm_nb_files, rm_fraction_for_selection)

     def remove_intermediate_weight_files(self):
         """
@@ -288,10 +416,29 @@ def remove_intermediate_weight_files(self):
         for f in files:
             os.remove(f)

+    def _move_all_files_wrapper(self, from_dir, to_dir):
+        """
+        Wrapper function for move_all_files(from_dir, to_dir).
+        In case of self.is_moe, all phases directories are moved as well.
+        :param from_dir: Origin directory where the files are located
+        :param to_dir: Destination directory where all files (including subdirectories directories) will be moved
+        :return:
+        """
+        if not self.is_moe:
+            move_all_files(from_dir, to_dir)
+        else:
+            for phase_idx in range(self.number_phases):
+                move_all_files(from_dir + f"phase{phase_idx}/", to_dir + f"phase{phase_idx}/")
+
+            if self.use_moe_staged_learning:
+                from_dir_final = from_dir + "phaseNone/"
+                to_dir_final = to_dir + "phaseNone/"
+                if os.path.isdir(from_dir_final) and os.path.isdir(to_dir_final):
+                    move_all_files(from_dir_final, to_dir_final)
+
     def replace_current_model_with_contender(self):
         """
         Moves the previous model into archive directory and the model-contender into the model directory
         """
-        move_all_files(self.model_dir, self.model_dir_archive)
-        move_all_files(self.model_contender_dir, self.model_dir)
-
+        self._move_all_files_wrapper(self.model_dir, self.model_dir_archive)
+        self._move_all_files_wrapper(self.model_contender_dir, self.model_dir)
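
Finally, a path-only sketch of how the updated get_current_model_tar_file() resolves its glob pattern after this change; no filesystem access, and the model_dir value is illustrative.

def resolve_tar_pattern(model_dir: str, is_moe: bool, phase=None, base_dir=None) -> str:
    # mirrors the phase/base_dir defaulting added above
    if phase is None:
        phase = "phase0" if is_moe else ""
    if base_dir is None:
        base_dir = model_dir
    return base_dir + phase + "/*.tar"

model_dir = "model/crazyhouse/"                                       # illustrative path
print(resolve_tar_pattern(model_dir, is_moe=False))                   # model/crazyhouse//*.tar
print(resolve_tar_pattern(model_dir, is_moe=True))                    # model/crazyhouse/phase0/*.tar
print(resolve_tar_pattern(model_dir, is_moe=True, phase="phase2"))    # model/crazyhouse/phase2/*.tar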
