Skip to content

Commit b4216ea

Browse files
committed
Add select_nn_index() for phase selection (#216)
- return 0 if no phases is enabled Set Game_Phase_Definition default to "lichess"
1 parent 36582c6 commit b4216ea

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

engine/src/searchthread.cpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -379,27 +379,35 @@ void SearchThread::create_mini_batch()
379379
}
380380
}
381381

382+
size_t SearchThread::select_nn_index()
383+
{
384+
if (nets.size() == 1) {
385+
return 0;
386+
}
387+
// determine majority class in current batch
388+
using pair_type = decltype(phaseCountMap)::value_type;
389+
auto pr = std::max_element
390+
(
391+
std::begin(phaseCountMap), std::end(phaseCountMap),
392+
[](const pair_type& p1, const pair_type& p2) {
393+
return p1.second < p2.second;
394+
}
395+
);
396+
397+
GamePhase majorityPhase = pr->first;
398+
399+
phaseCountMap.clear();
400+
return phaseToNetsIndex.at(majorityPhase);
401+
}
402+
382403
void SearchThread::thread_iteration()
383404
{
384405
create_mini_batch();
385406
#ifndef SEARCH_UCT
386407
if (newNodes->size() != 0) {
387-
388-
// determine majority class in current batch
389-
using pair_type = decltype(phaseCountMap)::value_type;
390-
auto pr = std::max_element
391-
(
392-
std::begin(phaseCountMap), std::end(phaseCountMap),
393-
[](const pair_type& p1, const pair_type& p2) {
394-
return p1.second < p2.second;
395-
}
396-
);
397-
398-
GamePhase majorityPhase = pr->first;
399408

400-
phaseCountMap.clear();
401409
// query the network that corresponds to the majority phase
402-
nets[phaseToNetsIndex.at(majorityPhase)]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
410+
nets[select_nn_index()]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
403411
set_nn_results_to_child_nodes();
404412
}
405413
#endif

engine/src/searchthread.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ class SearchThread : public NeuralNetAPIUser
196196
* @return Q-Value converted to double
197197
*/
198198
double get_current_transposition_q_value(const Node* currentNode, ChildIdx childIdx, uint_fast32_t transposVisits);
199+
200+
/**
201+
* @brief select_nn_index Returns the index according to the majority phase in the current batch.
202+
* If no phases is enabled, 0 will be returned.
203+
* @return Majority phase index or 0
204+
*/
205+
size_t select_nn_index();
199206
};
200207

201208
void run_search_thread(SearchThread *t);

engine/src/uci/optionsuci.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ void OptionsUCI::init(OptionsMap &o)
193193
o["Use_Raw_Network"] << Option(false);
194194
o["Virtual_Style"] << Option("virtual_mix", { "virtual_loss", "virtual_visit", "virtual_offset", "virtual_mix" });
195195
o["Virtual_Mix_Threshold"] << Option(1000, 1, 99999999);
196-
o["Game_Phase_Definition"] << Option("movecount", { "lichess", "movecount"});
196+
o["Game_Phase_Definition"] << Option("lichess", { "lichess", "movecount"});
197197
// additional UCI-Options for RL only
198198
#ifdef USE_RL
199199
o["Centi_Node_Random_Factor"] << Option(10, 0, 100);

0 commit comments

Comments
 (0)