From 9d8b426f03f74d1aa738c8593206e719ac03f0f6 Mon Sep 17 00:00:00 2001 From: yffbit Date: Tue, 16 Apr 2024 19:28:04 +0800 Subject: [PATCH 01/19] add slice_cfr --- .gitignore | 1 + CMakeLists.txt | 106 ++++ TexasSolverGui.pro | 2 +- benchmark/texassolver.txt | 41 ++ imgs/texassolver_logo.rc | 1 + include/runtime/PokerSolver.h | 6 +- include/solver/slice_cfr.h | 181 +++++++ include/tools/CommandLineTool.h | 1 + include/tools/utils.h | 19 + src/compairer/Dic5Compairer.cpp | 2 + src/console.cpp | 2 +- src/runtime/PokerSolver.cpp | 59 +-- src/runtime/qsolverjob.cpp | 7 +- src/solver/BestResponse.cpp | 2 +- src/solver/PCfrSolver.cpp | 2 +- src/solver/slice_cfr.cpp | 867 ++++++++++++++++++++++++++++++++ src/tools/CommandLineTool.cpp | 21 +- 17 files changed, 1283 insertions(+), 37 deletions(-) create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 benchmark/texassolver.txt create mode 100644 imgs/texassolver_logo.rc create mode 100644 include/solver/slice_cfr.h create mode 100644 src/solver/slice_cfr.cpp diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..84c048a --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/build/ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..11f5d36 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,106 @@ +cmake_minimum_required(VERSION 3.20) + +# project(TexasSolver LANGUAGES CXX CUDA) +project(TexasSolver LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +# set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# set(CMAKE_AUTOMOC ON) +set(CMAKE_AUTORCC ON) +set(CMAKE_AUTOUIC ON) + +set(CMAKE_CUDA_STANDARD 17) +# set(CMAKE_CUDA_STANDARD_REQUIRED ON) +message("${CMAKE_MINOR_VERSION}") +if(${CMAKE_MINOR_VERSION} GREATER_EQUAL 24) + # set(CMAKE_CUDA_ARCHITECTURES all) + set(CMAKE_CUDA_ARCHITECTURES all-major) + # set(CMAKE_CUDA_ARCHITECTURES native) +else() + set(CMAKE_CUDA_ARCHITECTURES OFF) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=all-major") +endif() +message("CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}") + +if((DEFINED CMAKE_BUILD_TYPE) AND (CMAKE_BUILD_TYPE STREQUAL Debug)) + set(CMAKE_CUDA_FLAGS "-g -G ${CMAKE_CUDA_FLAGS}") +endif() + +set(CMAKE_INCLUDE_CURRENT_DIR ON) +include_directories(include) + +find_package(OpenMP REQUIRED) +message("OpenMP_CXX_FLAGS=${OpenMP_CXX_FLAGS}") + +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler ${OpenMP_CXX_FLAGS}") +message("CMAKE_CUDA_FLAGS=${CMAKE_CUDA_FLAGS}") + +find_package(QT NAMES Qt6 Qt5 REQUIRED COMPONENTS Widgets) +set(QT_MAJOR Qt${QT_VERSION_MAJOR}) +message("QT_MAJOR=${QT_MAJOR}") +find_package(${QT_MAJOR} REQUIRED COMPONENTS Core Widgets LinguistTools) + +file(GLOB_RECURSE SRC src/*.cpp) +file(GLOB GUI_SRC *.cpp src/ui/*.cpp src/runtime/qsolverjob.cpp) +file(GLOB API_SRC src/api.cpp) +file(GLOB EXE_SRC src/console.cpp) +list(REMOVE_ITEM SRC ${GUI_SRC} ${EXE_SRC} ${API_SRC}) +# message("SRC=${SRC}") +# message("GUI_SRC=${GUI_SRC}") +# message("API_SRC=${API_SRC}") +# message("EXE_SRC=${EXE_SRC}") +# file(GLOB_RECURSE CUDA_SRC src/*.cu) +message("CUDA_SRC=${CUDA_SRC}") + +set(BASE_LIB TexasSolver) +add_library(${BASE_LIB} ${SRC} ${CUDA_SRC}) +target_link_libraries(${BASE_LIB} PUBLIC ${QT_MAJOR}::Core OpenMP::OpenMP_CXX) +# set_target_properties(${BASE_LIB} PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + +set(API_TARGET api) +add_library(${API_TARGET} SHARED ${API_SRC}) +target_link_libraries(${API_TARGET} PUBLIC ${BASE_LIB}) + +set(EXE console_solver) +add_executable(${EXE} ${EXE_SRC}) +target_link_libraries(${EXE} PRIVATE ${BASE_LIB}) +# target_link_options(${EXE} PUBLIC "/NODEFAULTLIB:LIBCMT") + +file(GLOB FORMS *.ui) +file(GLOB RESOURCES *.qrc) +file(GLOB TS_FILES *.ts) +file(GLOB QM_FILES *.qm) +# message("FORMS=${FORMS}") +# message("RESOURCES=${RESOURCES}") +# message("TS_FILES=${TS_FILES}") +# message("QM_FILES=${QM_FILES}") + +SET(ICON_NAME texassolver_logo) +if(WIN32) + file(GLOB ICON_FILE imgs/${ICON_NAME}.rc) +elseif(APPLE) + set(MACOSX_BUNDLE_ICON_FILE ${ICON_NAME}.icns) + file(GLOB ICON_FILE imgs/${ICON_NAME}.icns) + set_source_files_properties(${ICON_FILE} PROPERTIES MACOSX_PACKAGE_LOCATION "Resources") +endif() +# message("ICON_FILE=${ICON_FILE}") + +# set(CMAKE_AUTOMOC ON) doesn't work +# Q_OBJECT header +file(GLOB HEADERS *.h include/ui/*.h include/runtime/qsolverjob.h) +# message("HEADERS=${HEADERS}") +if(${QT_VERSION_MAJOR} GREATER_EQUAL 6) + qt6_wrap_cpp(GUI_SRC ${HEADERS}) +else() + qt5_wrap_cpp(GUI_SRC ${HEADERS}) +endif() +# message("GUI_SRC=${GUI_SRC}") + +set(GUI TexasSolverGui) +add_executable(${GUI} ${GUI_SRC} ${RESOURCES} ${FORMS} ${ICON_FILE}) +target_link_libraries(${GUI} PRIVATE ${QT_MAJOR}::Widgets ${BASE_LIB}) +set_target_properties(${GUI} PROPERTIES + WIN32_EXECUTABLE ON + MACOSX_BUNDLE ON +) diff --git a/TexasSolverGui.pro b/TexasSolverGui.pro index c932ba6..2d6e4d8 100644 --- a/TexasSolverGui.pro +++ b/TexasSolverGui.pro @@ -64,7 +64,7 @@ SOURCES += \ mainwindow.cpp \ src/Deck.cpp \ src/Card.cpp \ - src/console.cpp \ + # src/console.cpp \ src/GameTree.cpp \ src/library.cpp \ src/compairer/Dic5Compairer.cpp \ diff --git a/benchmark/texassolver.txt b/benchmark/texassolver.txt new file mode 100644 index 0000000..9d37c5f --- /dev/null +++ b/benchmark/texassolver.txt @@ -0,0 +1,41 @@ +set_pot 10 +set_effective_stack 95 +#set_board Qs,Jh,2h,4d +#set_range_oop AA,KK,QQ,JJ +#set_range_ip QQ:0.5,JJ:0.75 +set_board Qs,Jh,2h +set_range_oop AA,KK,QQ,JJ,TT,99:0.75,88:0.75,77:0.5,66:0.25,55:0.25,AK,AQs,AQo:0.75,AJs,AJo:0.5,ATs:0.75,A6s:0.25,A5s:0.75,A4s:0.75,A3s:0.5,A2s:0.5,KQs,KQo:0.5,KJs,KTs:0.75,K5s:0.25,K4s:0.25,QJs:0.75,QTs:0.75,Q9s:0.5,JTs:0.75,J9s:0.75,J8s:0.75,T9s:0.75,T8s:0.75,T7s:0.75,98s:0.75,97s:0.75,96s:0.5,87s:0.75,86s:0.5,85s:0.5,76s:0.75,75s:0.5,65s:0.75,64s:0.5,54s:0.75,53s:0.5,43s:0.5 +set_range_ip QQ:0.5,JJ:0.75,TT,99,88,77,66,55,44,33,22,AKo:0.25,AQs,AQo:0.75,AJs,AJo:0.75,ATs,ATo:0.75,A9s,A8s,A7s,A6s,A5s,A4s,A3s,A2s,KQ,KJ,KTs,KTo:0.5,K9s,K8s,K7s,K6s,K5s,K4s:0.5,K3s:0.5,K2s:0.5,QJ,QTs,Q9s,Q8s,Q7s,JTs,JTo:0.5,J9s,J8s,T9s,T8s,T7s,98s,97s,96s,87s,86s,76s,75s,65s,64s,54s,53s,43s +set_bet_sizes oop,flop,bet,100 +set_bet_sizes oop,flop,raise,50 +set_bet_sizes oop,flop,allin +set_bet_sizes ip,flop,bet,100 +set_bet_sizes ip,flop,raise,50 +set_bet_sizes ip,flop,allin +set_bet_sizes oop,turn,bet,100 +set_bet_sizes oop,turn,donk,100 +set_bet_sizes oop,turn,raise,50 +set_bet_sizes oop,turn,allin +set_bet_sizes ip,turn,bet,100 +set_bet_sizes ip,turn,raise,50 +set_bet_sizes oop,river,bet,100 +set_bet_sizes oop,river,donk,100 +set_bet_sizes oop,river,raise,50 +set_bet_sizes oop,river,allin +set_bet_sizes ip,river,bet,100 +set_bet_sizes ip,river,raise,50 +set_bet_sizes ip,river,allin +set_allin_threshold 1.0 +set_raise_limit 2 +build_tree +estimate_tree_memory +set_thread_num 6 +#set_thread_num 81920 +set_slice_cfr 1 +set_accuracy 0.3 +set_max_iteration 2000 +set_print_interval 10 +#set_use_isomorphism 1 +start_solve +set_dump_rounds 1 +#dump_result output_result.json diff --git a/imgs/texassolver_logo.rc b/imgs/texassolver_logo.rc new file mode 100644 index 0000000..e0626ca --- /dev/null +++ b/imgs/texassolver_logo.rc @@ -0,0 +1 @@ +IDI_ICON1 ICON "texassolver_logo.ico" \ No newline at end of file diff --git a/include/runtime/PokerSolver.h b/include/runtime/PokerSolver.h index 705919b..b454f65 100644 --- a/include/runtime/PokerSolver.h +++ b/include/runtime/PokerSolver.h @@ -12,6 +12,7 @@ #include "include/solver/CfrSolver.h" #include "include/solver/PCfrSolver.h" #include "include/library.h" +#include "solver/slice_cfr.h" #include #include using namespace std; @@ -44,10 +45,11 @@ class PokerSolver { float accuracy, bool use_isomorphism, int use_halffloats, - int threads + int threads, + bool slice_cfr = false ); void stop(); - long long estimate_tree_memory(QString range1,QString range2,QString board); + long long estimate_tree_memory(string& p1_range, string& p2_range, string& board); vector player1Range; vector player2Range; void dump_strategy(QString dump_file,int dump_rounds); diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h new file mode 100644 index 0000000..a79df72 --- /dev/null +++ b/include/solver/slice_cfr.h @@ -0,0 +1,181 @@ +#ifndef _SLICE_CFR_H_ +#define _SLICE_CFR_H_ + +#include +#include +#include +#include +#include "nodes/GameTreeNode.h" +#include "solver/PCfrSolver.h" +#include +#include + +using std::vector; +using std::unordered_set; +using std::unordered_map; +using std::dynamic_pointer_cast; +using std::mutex; + +#define N_CARD 52 +#define N_PLAYER 2 +#define P0 0 +#define P1 1 +#define CHANCE_PLAYER N_PLAYER + +#define N_ROUND 4 +#define PREFLOP_ROUND 0 +#define FLOP_ROUND 1 +#define TURN_ROUND 2 +#define RIVER_ROUND 3 + +#define FOLD_TYPE 0 +#define SHOWDOWN_TYPE 1 +#define N_LEAF_TYPE 2 + +#define N_TYPE 5 +#define N_TASK_SIZE 5 +#define two_card_hash(card1, card2) ((1LL<<(card1)) | (1LL<<(card2))) +#define tril_idx(r, c) (((r)*((r)-1)>>1)+(c)) // r>c>=0 + +#define get_size(n_act, n_hand) (((n_act) * 4 + 1) * (n_hand)) +#define cfv_offset(n_hand, act_idx) ((n_hand) * (act_idx)) +#define reach_prob_offset(n_act, n_hand, act_idx) (((n_act) * 3 + (act_idx)) * (n_hand)) +#define reach_prob_to_cfv(n_act, n_hand) ((n_act) * (n_hand) * 3) + +struct Node { + int n_act = 0;// 动作数 + int parent_offset = -1;// 本节点对应的父节点数据reach_prob的偏移量 + float *parent_cfv = nullptr; + mutex *mtx = nullptr; + float *data = nullptr;// cfv,regret_sum,strategy_sum,reach_prob,sum +}; +struct LeafNode { + float *reach_prob[N_PLAYER] = {nullptr,nullptr}; + size_t info = 0; +}; +struct PreLeafNode { + PreLeafNode(float *cfv):cfv(cfv) {} + float *cfv = nullptr; + vector leaf_node_idx; +}; +struct DFSNode { + DFSNode(int player, int n_act, int parent_act, int info, int parent_dfs_idx, int parent_p0_act, int parent_p0_idx, int parent_p1_act, int parent_p1_idx) + :player(player), n_act(n_act), parent_act(parent_act), info(info), parent_dfs_idx(parent_dfs_idx) + , parent_p0_act(parent_p0_act), parent_p0_idx(parent_p0_idx), parent_p1_act(parent_p1_act), parent_p1_idx(parent_p1_idx) {} + int player = -1;// 活动玩家(叶子节点时为父节点玩家) + int n_act = 0;// 动作数 + int parent_act = -1;// 本节点对应的父节点动作索引 + int info = 0; + int parent_dfs_idx = -1; + int parent_p0_act = -1; + int parent_p0_idx = -1; + int parent_p1_act = -1; + int parent_p1_idx = -1; +}; + +struct StrengthData { + StrengthData(int size, const RiverCombs *p):size(size), data(p) {} + int size = 0; + const RiverCombs *data = nullptr; +}; + +class SliceCFR : public Solver { +public: + SliceCFR( + shared_ptr tree, + vector &range1, + vector &range2, + vector &initial_board, + shared_ptr compairer, + Deck &deck, + int train_step, + int print_interval, + float accuracy, + int n_thread + ); + ~SliceCFR(); + size_t estimate_tree_size(); + void train(); + vector exploitability(); + void stop(); + json dumps(bool with_status, int depth); + vector>> get_strategy(shared_ptr node, vector cards); + vector>> get_evs(shared_ptr node, vector cards); +private: + atomic_bool stop_flag = false; + bool init_succ = false; + int n_thread = 0; + int thread_per_block = 32; + int steps = 0, interval = 0, n_card = N_CARD, min_card = 0; + int init_round = 0; + int dfs_idx = 0;// 先序遍历 + size_t init_board = 0; + int hand_size[N_PLAYER]; + float norm = 1;// 根节点概率归一化系数 + float tol = 0.01;// exploitability容忍度 + float alpha = 1.5, beta = 0, gamma = 2; + float pos_coef = 0, neg_coef = 0, coef = 0; + RiverRangeManager rrm; + vector hand_card;// p0_card1,p0_card2,p1_card1,p1_card2,相对于min_card的偏移量 + int *hand_card_ptr[N_PLAYER] {nullptr,nullptr}; + vector hand_hash; + size_t *hand_hash_ptr[N_PLAYER] {nullptr,nullptr}; + vector poss_card; + int chance_branch[N_ROUND]; + int chance_den[N_ROUND]; + vector same_hand_idx; + int *same_hand_ptr[N_PLAYER] {nullptr,nullptr}; + vector> ranges; + vector dfs_node; + vector dfs_idx_map;// dfs遍历的每个节点在cuda中的索引 + int node_cnt[N_TYPE]; + int n_leaf_node = 0; + int n_player_node = 0; + vector> leaf_node_dfs; + vector chance_node; + vector> ev; + float *ev_ptr = nullptr; + vector>> slice; + vector> slice_offset; + vector root_cfv, root_prob;// P0_cfv,P1_cfv,P0_prob,P1_prob + float *root_prob_ptr[N_PLAYER] {nullptr,nullptr}; + float *root_cfv_ptr[N_PLAYER] {nullptr,nullptr}; + shared_ptr tree = nullptr; + Deck& deck; + void init_hand_card(vector &range1, vector &range2); + void init_hand_card(vector &range, vector &cards, vector &prob, size_t board, vector &out); + void init_same_hand_idx(); + void init_min_card(); + size_t init_memory(shared_ptr compairer); + size_t init_player_node(); + size_t init_leaf_node(); + void set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset, mutex *&mtx); + void normalization(); + size_t init_strength_table(shared_ptr compairer); + void dfs(shared_ptr node, int parent_act=-1, int parent_dfs_idx=-1, int parent_p0_act=-1, int parent_p0_idx=-1, int parent_p1_act=-1, int parent_p1_idx=-1, int cnt0=0, int cnt1=0, int info=0); + void init_poss_card(Deck& deck, size_t board); + void step(int iter, int player, bool best_cfv=false); + void leaf_cfv(int player); + void fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, size_t board); + void sd_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, int idx); + void append_node_idx(int p_idx, int act_idx, int player, int cpu_node_idx); + vector> pre_leaf_node_map;// [dfs_idx,act_idx] + vector> pre_leaf_node;// [player,idx] + vector> root_child_idx; + vector leaf_node; + vector player_node; + Node *player_node_ptr = nullptr; + int sd_offset = 0; + vector cpu_cfv; + vector mtx; + vector> mtx_map; + int mtx_idx = N_PLAYER; + vector> strength; + size_t _estimate_tree_size(shared_ptr node); + void _reach_prob(int player, bool best_cfv=false); + void _rm(int player, bool best_cfv=false); + void clear_data(int player); + void clear_root_cfv(); +}; + +#endif // _SLICE_CFR_H_ diff --git a/include/tools/CommandLineTool.h b/include/tools/CommandLineTool.h index 3b078fa..a6482a5 100644 --- a/include/tools/CommandLineTool.h +++ b/include/tools/CommandLineTool.h @@ -41,6 +41,7 @@ class CommandLineTool{ int max_iteration=100; int use_isomorphism=0; int print_interval=10; + bool slice_cfr = false; int dump_rounds = 1; shared_ptr gtbs; }; diff --git a/include/tools/utils.h b/include/tools/utils.h index c4f9ef9..6009be1 100644 --- a/include/tools/utils.h +++ b/include/tools/utils.h @@ -64,4 +64,23 @@ void exchange_color(vector& value,vector range,int rank1,int ra //throw runtime_error("exiting...here..."); } +#include +class Timer { +public: + Timer() { + reset(); + } + void reset() { + start = std::chrono::steady_clock::now(); + } + int64_t ms() { + return std::chrono::duration_cast(std::chrono::steady_clock::now()-start).count(); + } + int64_t us() { + return std::chrono::duration_cast(std::chrono::steady_clock::now()-start).count(); + } +private: + std::chrono::steady_clock::time_point start {}; +}; + #endif //BINDSOLVER_UTILS_H diff --git a/src/compairer/Dic5Compairer.cpp b/src/compairer/Dic5Compairer.cpp index 8410ea8..e1133ff 100644 --- a/src/compairer/Dic5Compairer.cpp +++ b/src/compairer/Dic5Compairer.cpp @@ -9,7 +9,9 @@ #include #include #include "time.h" +#ifndef _MSC_VER #include "unistd.h" +#endif #define SUIT_0_MASK 0x1111111111111 #define SUIT_1_MASK 0x2222222222222 diff --git a/src/console.cpp b/src/console.cpp index d83e852..bd55432 100644 --- a/src/console.cpp +++ b/src/console.cpp @@ -4,7 +4,7 @@ #include "include/tools/CommandLineTool.h" #include "include/tools/argparse.hpp" -int main_backup(int argc,const char **argv) { +int main(int argc,const char **argv) { ArgumentParser parser; parser.addArgument("-i", "--input_file", 1, true); diff --git a/src/runtime/PokerSolver.cpp b/src/runtime/PokerSolver.cpp index bff1619..1cf551f 100644 --- a/src/runtime/PokerSolver.cpp +++ b/src/runtime/PokerSolver.cpp @@ -67,29 +67,26 @@ void PokerSolver::stop(){ } } -long long PokerSolver::estimate_tree_memory(QString range1,QString range2,QString board){ +long long PokerSolver::estimate_tree_memory(string &p1_range, string &p2_range, string &board){ if(this->game_tree == nullptr){ qDebug().noquote() << QObject::tr("Please buld tree first."); return 0; } else{ - string player1RangeStr = range1.toStdString(); - string player2RangeStr = range2.toStdString(); - - vector board_str_arr = string_split(board.toStdString(),','); + vector board_str_arr = string_split(board,','); vector initialBoard; for(string one_board_str:board_str_arr){ initialBoard.push_back(Card::strCard2int(one_board_str)); } - vector range1 = PrivateRangeConverter::rangeStr2Cards(player1RangeStr,initialBoard); - vector range2 = PrivateRangeConverter::rangeStr2Cards(player2RangeStr,initialBoard); + vector range1 = PrivateRangeConverter::rangeStr2Cards(p1_range,initialBoard); + vector range2 = PrivateRangeConverter::rangeStr2Cards(p2_range,initialBoard); return this->game_tree->estimate_tree_memory(this->deck.getCards().size() - initialBoard.size(),range1.size(),range2.size()); } } void PokerSolver::train(string p1_range, string p2_range, string boards, string log_file, int iteration_number, - int print_interval, string algorithm,int warmup,float accuracy,bool use_isomorphism, int use_halffloats, int threads) { + int print_interval, string algorithm,int warmup,float accuracy,bool use_isomorphism, int use_halffloats, int threads, bool slice_cfr) { string player1RangeStr = p1_range; string player2RangeStr = p2_range; @@ -107,26 +104,32 @@ void PokerSolver::train(string p1_range, string p2_range, string boards, string this->player2Range = noDuplicateRange(range2,initial_board_long); string logfile_name = log_file; - this->solver = make_shared( - game_tree - , range1 - , range2 - , initialBoard - , compairer - , deck - , iteration_number - , false - , print_interval - , logfile_name - , algorithm - , Solver::MonteCarolAlg::NONE - , warmup - , accuracy - , use_isomorphism - , use_halffloats - , threads - ); - this->solver->train(); + if(solver) solver.reset();// 释放内存 + if(slice_cfr) { + solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads); + } + else { + solver = make_shared( + game_tree + , range1 + , range2 + , initialBoard + , compairer + , deck + , iteration_number + , false + , print_interval + , logfile_name + , algorithm + , Solver::MonteCarolAlg::NONE + , warmup + , accuracy + , use_isomorphism + , use_halffloats + , threads + ); + } + solver->train(); } void PokerSolver::dump_strategy(QString dump_file,int dump_rounds) { diff --git a/src/runtime/qsolverjob.cpp b/src/runtime/qsolverjob.cpp index 27a663f..4c71221 100644 --- a/src/runtime/qsolverjob.cpp +++ b/src/runtime/qsolverjob.cpp @@ -135,10 +135,13 @@ void QSolverJob::solving(){ long long QSolverJob::estimate_tree_memory(QString range1,QString range2,QString board){ qDebug().noquote() << tr("Estimating tree memory..");//.toStdString() << endl; + string p1_range = range1.toStdString(); + string p2_range = range2.toStdString(); + string board_str = board.toStdString(); if(this->mode == Mode::HOLDEM){ - return ps_holdem.estimate_tree_memory(range1,range2,board); + return ps_holdem.estimate_tree_memory(p1_range, p1_range, board_str); }else if(this->mode == Mode::SHORTDECK){ - return ps_shortdeck.estimate_tree_memory(range1,range2,board); + return ps_shortdeck.estimate_tree_memory(p1_range, p1_range, board_str); } return 0; } diff --git a/src/solver/BestResponse.cpp b/src/solver/BestResponse.cpp index 90d174b..0c46b36 100644 --- a/src/solver/BestResponse.cpp +++ b/src/solver/BestResponse.cpp @@ -130,7 +130,7 @@ BestResponse::chanceBestReponse(shared_ptr node, int player,const ve vector> results(node->getCards().size()); #pragma omp parallel for - for(std::size_t card = 0;card < node->getCards().size();card ++) { + for(std::int64_t card = 0;card < node->getCards().size();card ++) { shared_ptr one_child = node->getChildren(); Card one_card = node->getCards()[card]; uint64_t card_long = Card::boardInt2long(one_card.getCardInt()); diff --git a/src/solver/PCfrSolver.cpp b/src/solver/PCfrSolver.cpp index 29395ed..d73aa53 100644 --- a/src/solver/PCfrSolver.cpp +++ b/src/solver/PCfrSolver.cpp @@ -299,7 +299,7 @@ PCfrSolver::chanceUtility(int player, shared_ptr node, const vector< } #pragma omp parallel for schedule(static) - for(std::size_t valid_ind = 0;valid_ind < valid_cards.size();valid_ind++) { + for(std::int64_t valid_ind = 0;valid_ind < valid_cards.size();valid_ind++) { int card = valid_cards[valid_ind]; shared_ptr one_child = node->getChildren(); Card *one_card = const_cast(&(node->getCards()[card])); diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp new file mode 100644 index 0000000..138d9f6 --- /dev/null +++ b/src/solver/slice_cfr.cpp @@ -0,0 +1,867 @@ +#include "solver/slice_cfr.h" +#include "ranges/RiverRangeManager.h" + +// 数组poss_card的索引[0,51]-->[1,52],8位二进制编码,最多选两个,占用高16位,低16位预留其他用途 +#define code_idx0(i) (((i)+1)<<24) +#define decode_idx0(x) (((x)>>24) - 1) +#define code_idx1(i) (((i)+1)<<16) +#define decode_idx1(x) ((((x)>>16)&0xff) - 1) + +void print_array(int *arr, int n) { + if(arr != nullptr && n > 0) { + printf("%d", arr[0]); + for(int i = 1; i < n; i++) printf(",%d", arr[i]); + } + printf("\n"); +} + +void test_parallel_for(int n_thread, int n = 100) { + vector cnt(n_thread); + #pragma omp parallel for + for(int i = 0; i < n; i++) { + cnt[omp_get_thread_num()]++; + } + print_array(cnt.data(), n_thread); +} + +inline bool cards_valid(size_t hash1, size_t hash2) { + return (hash1 & hash2) == 0; +} + +typedef void (*node_func)(Node *, int); + +void rm_avg(Node *node, int n_hand) { + int size = node->n_act * n_hand; + int i = 0, h = 0, sum_offset = size << 1; + float *data = node->data + (size << 1);// strategy_sum + float sum = 0; + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += data[i]; + data[sum_offset+h] = sum; + } +} +void rm(Node *node, int n_hand) { + int size = node->n_act * n_hand; + int i = 0, h = 0, sum_offset = size * 3; + float *data = node->data + size;// regret_sum + float sum = 0; + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += max(0.0f, data[i]); + data[sum_offset+h] = sum; + } +} +void reach_prob_avg(Node *node, int n_hand) { + int n_act = node->n_act, size = n_act * n_hand; + int i = 0, h = 0, sum_offset = size << 1; + float *data = node->data + (size << 1);// strategy_sum + float *parent_prob = node->parent_cfv + node->parent_offset, temp = 0; + for(h = 0; h < n_hand; h++) { + if(data[sum_offset+h] == 0) {// 1/n_act + temp = parent_prob[h] / n_act; + for(i = h; i < size; i += n_hand) data[size+i] = temp; + } + else { + temp = parent_prob[h] / data[sum_offset+h]; + for(i = h; i < size; i += n_hand) data[size+i] = temp * data[i]; + } + } +} +void reach_prob(Node *node, int n_hand) { + int n_act = node->n_act, size = n_act * n_hand; + int i = 0, h = 0, rp_offset = size << 1, sum_offset = rp_offset + size; + float *data = node->data + size;// regret_sum + float *parent_prob = node->parent_cfv + node->parent_offset, temp = 0; + for(h = 0; h < n_hand; h++) { + if(data[sum_offset+h] == 0) {// 1/n_act + temp = parent_prob[h] / n_act; + for(i = h; i < size; i += n_hand) data[rp_offset+i] = temp; + } + else { + temp = parent_prob[h] / data[sum_offset+h]; + for(i = h; i < size; i += n_hand) data[rp_offset+i] = temp * max(0.0f, data[i]); + } + } +} +// 子节点cfv取最大值 +void best_cfv_up(Node *node, int n_hand) { + int size = node->n_act * n_hand; + int i = 0, h = 0; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + mutex *mtx = node->mtx; + for(h = 0; h < n_hand; h++) { + val = cfv[h];// 第一个 + for(i = h+n_hand; i < size; i += n_hand) val = max(val, cfv[i]); + mtx->lock(); + parent_cfv[h] += val;// 需要加锁 + mtx->unlock(); + } +} +// 子节点cfv加权求和 +void cfv_up(Node *node, int n_hand) { + int n_act = node->n_act, size = n_act * n_hand; + int i = 0, h = 0, sum_offset = size << 2; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + float *regret_sum = cfv + size; + mutex *mtx = node->mtx; + for(h = 0; h < n_hand; h++) { + val = 0; + if(cfv[sum_offset+h] == 0) { + for(i = h; i < size; i += n_hand) val += cfv[i]; + val /= n_act;// uniform strategy + } + else { + for(i = h; i < size; i += n_hand) { + val += cfv[i] * max(0.0f, regret_sum[i]); + } + val /= cfv[sum_offset+h]; + } + // cfv[sum_offset+h] = val; + mtx->lock(); + parent_cfv[h] += val;// 需要加锁 + mtx->unlock(); + for(i = h; i < size; i += n_hand) regret_sum[i] += cfv[i] - val;// 更新regret_sum + val = 0; + for(i = h; i < size; i += n_hand) val += max(0.0f, regret_sum[i]); + cfv[sum_offset+h] = val;// 求和 + } + for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} +// 在cfv_up前执行 +void updata_data(Node *node, int n_hand, float pos_coef, float neg_coef, float coef) { + int size = node->n_act * n_hand, i = 0; + float *regret_sum = node->data + size, *strategy_sum = regret_sum + size; + for(i = 0; i < size; i++) { + regret_sum[i] *= regret_sum[i] > 0 ? pos_coef : neg_coef; + strategy_sum[i] = strategy_sum[i] * coef + strategy_sum[size+i]; + } +} + +// #define TIME_LOG +#ifdef TIME_LOG +atomic_ullong fold_time[16] = {0}, sd_time[16] = {0}; +#endif + +void SliceCFR::leaf_cfv(int player) { +#ifdef TIME_LOG + Timer timer; + for(int i = 0; i < n_thread; i++) { + fold_time[i].store(0), sd_time[i].store(0); + } +#endif + int opp = 1 - player; + int my_hand = hand_size[player], opp_hand = hand_size[opp]; + vector &vec = pre_leaf_node[player]; + int64_t n = vec.size(); + #pragma omp parallel for schedule(dynamic) + // #pragma omp parallel for + for(int64_t i = 0; i < n; i++) { + // printf("omp_get_thread_num():%d,%zd\n", omp_get_thread_num(), i); + float *cfv = vec[i].cfv; + // for(int j = 0; j < my_hand; j++) cfv[j] = 0; + for(int j : vec[i].leaf_node_idx) { + LeafNode &node = leaf_node[j]; + if(j < sd_offset) { + fold_cfv(player, cfv, node.reach_prob[opp], my_hand, opp_hand, ev_ptr[j], node.info); + } + else sd_cfv(player, cfv, node.reach_prob[opp], my_hand, opp_hand, ev_ptr[j], node.info); + } + } +#ifdef TIME_LOG + for(int i = 0; i < n_thread; i++) { + printf("%zd\t%zd\n", fold_time[i].load(), sd_time[i].load()); + } + // printf("leaf_cfv:%zd ms\n", timer.ms()); +#endif +} +void SliceCFR::fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, size_t board) { +#ifdef TIME_LOG + Timer timer; +#endif + if(player != P0) val = -val; + size_t *my_hash = hand_hash_ptr[player], *opp_hash = hand_hash_ptr[1-player]; + int *my_card = hand_card_ptr[player], *opp_card = hand_card_ptr[1-player]; + int *same_hand = same_hand_ptr[player], i = 0; + vector opp_prob_sum(n_card, 0); + float prob_sum = 0, temp = 0; + for(i = 0; i < opp_hand; i++) { + if(opp_hash[i] & board) continue;// 对方手牌与公共牌冲突 + temp = opp_reach[i]; + opp_prob_sum[opp_card[i]] += temp;// card1 + opp_prob_sum[opp_card[i+opp_hand]] += temp;// card2 + prob_sum += temp; + } + for(i = 0; i < my_hand; i++) { + if(my_hash[i] & board) { + // cfv[i] = 0;// 与公共牌冲突,cfv为0 + continue; + } + temp = same_hand[i] != -1 ? opp_reach[same_hand[i]] : 0;// 重复计算的部分 + cfv[i] += (prob_sum - opp_prob_sum[my_card[i]] - opp_prob_sum[my_card[i+my_hand]] + temp) * val; + } +#ifdef TIME_LOG + fold_time[omp_get_thread_num()] += timer.us(); +#endif +} +void SliceCFR::sd_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, int idx) { +#ifdef TIME_LOG + Timer timer; +#endif + vector &vec = strength[idx]; + const RiverCombs *my_data = vec[player].data, *opp_data = vec[1-player].data; + int my_size = vec[player].size, opp_size = vec[1-player].size, i = 0, j = 0, h = 0, rank = 0; + int *my_card = hand_card_ptr[player], *opp_card = hand_card_ptr[1-player]; + vector opp_prob_sum(n_card, 0); + float prob_sum = 0; + for(i = 0, j = 0; i < my_size; i++) {// strength值变小,己方手牌变强 + rank = my_data[i].rank; + for(; j < opp_size && opp_data[j].rank > rank; j++) {// (胜过对方条件下)找到对方的最强手牌 + h = opp_data[j].reach_prob_index; + opp_prob_sum[opp_card[h]] += opp_reach[h];// card1 + opp_prob_sum[opp_card[h+opp_hand]] += opp_reach[h];// card2 + prob_sum += opp_reach[h]; + } + h = my_data[i].reach_prob_index; + cfv[h] += (prob_sum - opp_prob_sum[my_card[h]] - opp_prob_sum[my_card[h+my_hand]]) * val; + } + prob_sum = 0; + for(h = 0; h < n_card; h++) opp_prob_sum[h] = 0; + for(i = my_size-1, j = opp_size-1; i >= 0; i--) {// strength值变大,己方手牌变弱 + rank = my_data[i].rank; + for(; j >= 0 && opp_data[j].rank < rank; j--) {// (败给对方条件下)找到对方的最弱手牌 + h = opp_data[j].reach_prob_index; + opp_prob_sum[opp_card[h]] += opp_reach[h];// card1 + opp_prob_sum[opp_card[h+opp_hand]] += opp_reach[h];// card2 + prob_sum += opp_reach[h]; + } + h = my_data[i].reach_prob_index; + cfv[h] += (opp_prob_sum[my_card[h]] + opp_prob_sum[my_card[h+my_hand]] - prob_sum) * val; + } +#ifdef TIME_LOG + sd_time[omp_get_thread_num()] += timer.us(); +#endif +} +void SliceCFR::append_node_idx(int p_idx, int act_idx, int player, int leaf_node_idx) { + if(p_idx == -1) { + root_child_idx[player].push_back(leaf_node_idx); + leaf_node[leaf_node_idx].reach_prob[player] = root_prob_ptr[player]; + return; + } + vector &vec = pre_leaf_node[player]; + int n_hand = hand_size[player], offset = reach_prob_to_cfv(dfs_node[p_idx].n_act, n_hand); + float *cfv = player_node[dfs_idx_map[p_idx]].data + cfv_offset(n_hand, act_idx); + if(pre_leaf_node_map[p_idx].empty()) pre_leaf_node_map[p_idx] = vector(dfs_node[p_idx].n_act, -1); + int &i = pre_leaf_node_map[p_idx][act_idx]; + if(i == -1) {// 未初始化 + i = vec.size(); + vec.emplace_back(cfv); + } + vec[i].leaf_node_idx.push_back(leaf_node_idx); + leaf_node[leaf_node_idx].reach_prob[player] = cfv + offset; +} +size_t SliceCFR::init_leaf_node() { + pre_leaf_node_map = vector>(dfs_idx); + pre_leaf_node = vector>(N_PLAYER); + root_child_idx = vector>(N_PLAYER); + leaf_node = vector(n_leaf_node); + int node_idx = 0; + for(int i = 0; i < N_LEAF_TYPE; i++) { + for(int idx : leaf_node_dfs[i]) { + DFSNode &node = dfs_node[idx]; + append_node_idx(node.parent_p0_idx, node.parent_p0_act, P0, node_idx); + append_node_idx(node.parent_p1_idx, node.parent_p1_act, P1, node_idx); + int j = decode_idx0(node.info), k = decode_idx1(node.info); + size_t info = init_board; + if(i == FOLD_TYPE) { + if(j != -1) info |= 1LL << poss_card[j]; + if(k != -1) info |= 1LL << poss_card[k]; + } + else { + if(j == -1) info = 0; + else if(k == -1) info = j; + else info = tril_idx(j, k); + } + leaf_node[node_idx++].info = info; + } + } + sd_offset = leaf_node_dfs[FOLD_TYPE].size(); + printf("%zd,%zd\n", pre_leaf_node[P0].size(), pre_leaf_node[P1].size()); + printf("%d,%d,%zd,%zd\n", n_leaf_node, node_idx, root_child_idx[P0].size(), root_child_idx[P1].size()); + + size_t max_val[N_PLAYER] = {0, 0}, min_val[N_PLAYER] = {INT_MAX, INT_MAX}; + for(int i = 0; i < N_PLAYER; i++) { + if(!root_child_idx[i].empty()) { + pre_leaf_node[i].emplace_back(root_cfv_ptr[i]); + pre_leaf_node[i].back().leaf_node_idx = root_child_idx[i]; + } + for(PreLeafNode &node : pre_leaf_node[i]) { + assert(node.cfv != nullptr); + max_val[i] = max(max_val[i], node.leaf_node_idx.size()); + min_val[i] = min(min_val[i], node.leaf_node_idx.size()); + } + } + printf("%zd,%zd,%zd,%zd\n", min_val[P0], max_val[P0], min_val[P1], max_val[P1]); + + ev[FOLD_TYPE].insert(ev[FOLD_TYPE].end(), ev[SHOWDOWN_TYPE].begin(), ev[SHOWDOWN_TYPE].end()); + ev[FOLD_TYPE].clear(); + ev_ptr = ev[FOLD_TYPE].data(); + size_t total = n_leaf_node * sizeof(LeafNode); + total += (pre_leaf_node[P0].size() + pre_leaf_node[P1].size()) * sizeof(PreLeafNode); + total += n_leaf_node * N_PLAYER * sizeof(int);// leaf_node_idx + return total; +} + +SliceCFR::SliceCFR( + shared_ptr tree, + vector &range1, + vector &range2, + vector &initial_board, + shared_ptr compairer, + Deck &deck, + int train_step, + int print_interval, + float accuracy, + int n_thread +):tree(tree), deck(deck), steps(train_step), interval(print_interval), n_thread(max(0,n_thread)), rrm(compairer) { + init_board = Card::boardInts2long(initial_board); + init_round = GameTreeNode::gameRound2int(tree->getRoot()->getRound()); + if(init_round < FLOP_ROUND) return; + init_hand_card(range1, range2); + if(hand_size[P0] == 0 || hand_size[P1] == 0) return; + init_same_hand_idx(); + init_min_card(); + init_poss_card(deck, init_board); + normalization(); + tol = accuracy / 100 * tree->getRoot()->getPot(); + if(this->n_thread == 0) this->n_thread = omp_get_num_procs(); + omp_set_num_threads(this->n_thread); + // test_parallel_for(this->n_thread); + + float unit = 1 << 20; + size_t size = estimate_tree_size(); + printf("estimate memory:%f MB\n", size/unit); + + leaf_node_dfs.resize(N_LEAF_TYPE); + ev.resize(N_LEAF_TYPE); + slice.resize(N_PLAYER); + dfs_idx = 0; + dfs(tree->getRoot(), -1, -1, -1, -1, -1, -1, 0, 0, 0); + + print_array(node_cnt, N_TYPE); + for(int i = 0; i < N_LEAF_TYPE; i++) { + printf("%zd,", leaf_node_dfs[i].size()); + assert(node_cnt[i] == leaf_node_dfs[i].size()); + } + for(int player = P0; player < N_PLAYER; player++) { + size = 0; + for(vector &nodes : slice[player]) size += nodes.size(); + printf("%zd,", size); + assert(size == node_cnt[N_LEAF_TYPE+player]); + } + printf("%zd\n", chance_node.size()); + assert(node_cnt[N_LEAF_TYPE+CHANCE_PLAYER] == chance_node.size()); + + if(dfs_idx == 0 || dfs_node[0].n_act == 0) return; + size = init_memory(compairer); + printf("%d nodes, total:%f MB\n", dfs_idx, size/unit); + init_succ = true; +} + +SliceCFR::~SliceCFR() { + for(Node &node : player_node) { + if(node.data) free(node.data); + } +} + +void SliceCFR::set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset, mutex *&mtx) { + if(player == -1) player = node.player;// 向上连接同玩家节点 + int p_idx = node.parent_p0_idx, act_idx = node.parent_p0_act;// 向上连接P0 + if(player != P0) {// 向上连接P1 + p_idx = node.parent_p1_idx; + act_idx = node.parent_p1_act; + } + if(p_idx == -1) { + cfv = root_cfv_ptr[player]; + offset = root_prob_ptr[player] - root_cfv_ptr[player]; + mtx = (mutex *)player; + } + else { + if(player != dfs_node[p_idx].player) throw runtime_error("player mismatch"); + cfv = player_node[dfs_idx_map[p_idx]].data + cfv_offset(hand_size[player], act_idx); + offset = reach_prob_to_cfv(dfs_node[p_idx].n_act, hand_size[player]); + if(mtx_map[p_idx].empty()) mtx_map[p_idx] = vector(dfs_node[p_idx].n_act, -1); + int &i = mtx_map[p_idx][act_idx]; + if(i == -1) i = mtx_idx++; + mtx = (mutex *)i; + } +} + +size_t SliceCFR::init_player_node() { + size_t total = 0, size = 0; + player_node = vector(n_player_node); + player_node_ptr = player_node.data(); + dfs_idx_map = vector(dfs_idx, -1); + slice_offset = vector>(N_PLAYER); + mtx_map = vector>(dfs_idx); + mtx_idx = N_PLAYER; + int mem_idx = 0; + for(int i = 0; i < N_PLAYER; i++) {// 枚举player + for(vector &nodes : slice[i]) {// 枚举slice + slice_offset[i].push_back(mem_idx); + for(int idx : nodes) {// 枚举node + DFSNode &node = dfs_node[idx]; + Node &target = player_node[mem_idx]; + target.n_act = node.n_act; + set_cfv_and_offset(node, -1, target.parent_cfv, target.parent_offset, target.mtx); + size = get_size(node.n_act, hand_size[node.player]) * sizeof(float); + target.data = (float *)malloc(size); + if(target.data == nullptr) throw runtime_error("malloc error"); + total += size; + dfs_idx_map[idx] = mem_idx++; + } + } + slice_offset[i].push_back(mem_idx); + } + mtx = vector(mtx_idx); + printf("%d,%d,%d\n", sizeof(mutex), mtx_idx, mtx_idx * sizeof(mutex)); + total += mtx_idx * sizeof(mutex); + for(int i : dfs_idx_map) { + if(i == -1) continue; + player_node[i].mtx = &mtx[(int)(player_node[i].mtx)]; + } + total += n_player_node * sizeof(Node); + return total; +} + +size_t SliceCFR::init_memory(shared_ptr compairer) { + size_t total = 0; + int n = root_prob.size(); + root_cfv = vector(n<<1, 0); + for(int i = 0; i < n; i++) root_cfv[n+i] = root_prob[i]; + total += n * 3 * sizeof(float); + root_cfv_ptr[P0] = root_cfv.data(); + root_cfv_ptr[P1] = root_cfv_ptr[P0] + hand_size[P0]; + root_prob_ptr[P0] = root_cfv_ptr[P0] + n; + root_prob_ptr[P1] = root_prob_ptr[P0] + hand_size[P0]; + + total += init_player_node(); + total += init_leaf_node(); + total += init_strength_table(compairer); + return total; +} + +size_t SliceCFR::init_strength_table(shared_ptr compairer) { + int n = poss_card.size(); + vector board_hash; + if(init_round == RIVER_ROUND) board_hash.push_back(init_board); + else if(init_round == TURN_ROUND) { + board_hash = vector(n, 0); + for(int i = 0; i < n; i++) board_hash[i] = init_board | (1LL<(n*(n-1)>>1, 0); + for(int i = 0; i < n; i++) { + for(int j = i+1; j < n; j++) { + board_hash[tril_idx(j, i)] = init_board | two_card_hash(poss_card[i], poss_card[j]); + } + } + } + n = board_hash.size(); + strength = vector>(n); + // omp_set_num_threads(omp_get_num_procs()); + #pragma omp parallel for + for(int i = 0; i < n; i++) { + // printf("omp_get_thread_num():%d,%d\n", omp_get_thread_num(), i); + const vector& p0_comb = rrm.getRiverCombos(P0, ranges[P0], board_hash[i]); + const vector& p1_comb = rrm.getRiverCombos(P1, ranges[P1], board_hash[i]); + strength[i].emplace_back(p0_comb.size(), p0_comb.data()); + strength[i].emplace_back(p1_comb.size(), p1_comb.data()); + } + size_t total = (n<<1) * sizeof(StrengthData), size = 0; + for(int i = 0; i < n; i++) size += strength[i][P0].size + strength[i][P1].size; + total += (size<<1) * sizeof(int);// rank,idx + return total; +} + +void SliceCFR::init_min_card() { + min_card = N_CARD; + int max_card = -1; + for(int card : hand_card) { + min_card = min(min_card, card); + max_card = max(max_card, card); + } + n_card = max_card - min_card + 1;// 52张牌中如果只用了连续的一段,可以节省内存 + for(int &card : hand_card) card -= min_card; +} + +void SliceCFR::init_hand_card(vector &range1, vector &range2) { + ranges = vector>(2); + vector cards;// card1,card2,card1,card2,... + init_hand_card(range1, cards, root_prob, init_board, ranges[P0]); + hand_size[P0] = root_prob.size(); + init_hand_card(range2, cards, root_prob, init_board, ranges[P1]); + hand_size[P1] = root_prob.size() - hand_size[P0]; + hand_card = vector(cards.size()); + hand_hash = vector(root_prob.size()); + int stop[N_PLAYER] = {hand_size[P0]<<1, cards.size()}; + int i = 0, j = 0, k = 0, n = 0; + for(int p = 0; p < N_PLAYER; p++) { + for(n = hand_size[p], i = j; j < stop[p]; j += 2, i++) { + hand_card[i] = cards[j]; + hand_card[i+n] = cards[j+1]; + hand_hash[k++] = two_card_hash(cards[j], cards[j+1]); + } + } + hand_card_ptr[P0] = hand_card.data(); + hand_card_ptr[P1] = hand_card_ptr[P0] + stop[P0]; + hand_hash_ptr[P0] = hand_hash.data(); + hand_hash_ptr[P1] = hand_hash_ptr[P0] + hand_size[P0]; +} + +void SliceCFR::init_hand_card(vector &range, vector &cards, vector &prob, size_t board, vector &out) { + unordered_set seen; + for(PrivateCards &hand : range) { + size_t hash = hand.toBoardLong(); + if(seen.count(hash)) continue;// 去重 + if(hash & board) continue;// 和公共牌冲突 + seen.insert(hash); + cards.push_back(min(hand.card1, hand.card2)); + cards.push_back(max(hand.card1, hand.card2)); + prob.push_back(hand.weight); + out.push_back(hand); + } +} + +void SliceCFR::init_same_hand_idx() { + int n = root_prob.size(), p0_size = hand_size[P0]; + same_hand_idx = vector(n, -1); + unordered_map hash2idx; + for(int h = 0; h < p0_size; h++) hash2idx[hand_hash[h]] = h; + for(int h = p0_size; h < n; h++) {// P1 + size_t hash = hand_hash[h]; + if(hash2idx.count(hash)) { + same_hand_idx[h] = hash2idx[hash]; + same_hand_idx[hash2idx[hash]] = h - p0_size; + } + } + same_hand_ptr[P0] = same_hand_idx.data(); + same_hand_ptr[P1] = same_hand_ptr[P0] + p0_size; +} + +void SliceCFR::normalization() { + int p0_size = hand_size[P0], n = root_prob.size(); + norm = 0; + // 每个history的概率为p0_prob*p1_prob*chance_prob*mask/norm + // p0手牌,p1手牌,公共牌之间有冲突时mask=0,无冲突时mask=1 + // cfr迭代过程中,不需要考虑norm + for(int p0 = 0; p0 < p0_size; p0++) { + for(int p1 = p0_size; p1 < n; p1++) { + if(!cards_valid(hand_hash[p0], hand_hash[p1])) continue; + norm += root_prob[p0] * root_prob[p1]; + } + } +} + +size_t SliceCFR::estimate_tree_size() { + for(int i = 0; i < N_TYPE; i++) node_cnt[i] = 0; + if(tree == nullptr) return 0; + size_t size = _estimate_tree_size(tree->getRoot()); + n_leaf_node = node_cnt[FOLD_TYPE] + node_cnt[SHOWDOWN_TYPE]; + n_player_node = node_cnt[N_LEAF_TYPE+P0] + node_cnt[N_LEAF_TYPE+P1]; + size *= sizeof(float); + size += n_leaf_node * sizeof(LeafNode); + size += n_player_node * sizeof(Node); + return size; +} + +size_t SliceCFR::_estimate_tree_size(shared_ptr node) { + int type = node->getType(), round = GameTreeNode::gameRound2int(node->getRound()), n_act = 0; + size_t size = 0; + if(type == GameTreeNode::ACTION) { + shared_ptr act_node = dynamic_pointer_cast(node); + vector> children = act_node->getChildrens(); + n_act = children.size(); + int player = act_node->getPlayer(); + node_cnt[N_LEAF_TYPE + player]++; + size += get_size(n_act, hand_size[player]); + for(int i = 0; i < n_act; i++) size += _estimate_tree_size(children[i]); + } + else if(type == GameTreeNode::CHANCE) { + shared_ptr chance_node = dynamic_pointer_cast(node); + shared_ptr children = chance_node->getChildren();// 不为null + int child_type = children->getType(); + n_act = chance_branch[round] + 4; + node_cnt[N_LEAF_TYPE + CHANCE_PLAYER]++; + if(child_type == GameTreeNode::ACTION || child_type == GameTreeNode::SHOWDOWN) { + for(int i = 0; i < n_act; i++) size += _estimate_tree_size(children); + } + else {// CHANCE之后接着CHANCE,再接着SHOWDOWN + node_cnt[SHOWDOWN_TYPE] += (n_act*(n_act-1)>>1); + } + } + else if(type == GameTreeNode::SHOWDOWN) node_cnt[SHOWDOWN_TYPE]++; + else node_cnt[FOLD_TYPE]++; + return size; +} + +void SliceCFR::dfs(shared_ptr node, int parent_act, int parent_dfs_idx, int parent_p0_act, int parent_p0_idx, int parent_p1_act, int parent_p1_idx, int cnt0, int cnt1, int info) { + int curr_idx = dfs_idx++; + int type = node->getType(), round = GameTreeNode::gameRound2int(node->getRound()), n_act = 0; + if(type == GameTreeNode::ACTION) { + shared_ptr act_node = dynamic_pointer_cast(node); + int player = act_node->getPlayer(); + vector> children = act_node->getChildrens(); + n_act = children.size(); + dfs_node.emplace_back(player, n_act, parent_act, info | round, parent_dfs_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + vector> &player_slice = slice[player]; + if(player == P0) { + if(player_slice.size() == cnt0) player_slice.emplace_back(); + player_slice[cnt0++].push_back(curr_idx); + for(int i = 0; i < n_act; i++) dfs(children[i], i, curr_idx, i, curr_idx, parent_p1_act, parent_p1_idx, cnt0, cnt1, info); + } + else {// P1 + if(player_slice.size() == cnt1) player_slice.emplace_back(); + player_slice[cnt1++].push_back(curr_idx); + for(int i = 0; i < n_act; i++) dfs(children[i], i, curr_idx, parent_p0_act, parent_p0_idx, i, curr_idx, cnt0, cnt1, info); + } + } + else if(type == GameTreeNode::CHANCE) { + shared_ptr chance_node = dynamic_pointer_cast(node); + shared_ptr children = chance_node->getChildren();// 不为null + int child_type = children->getType(); + n_act = chance_branch[round] + 4; + this->chance_node.push_back(curr_idx); + if(child_type == GameTreeNode::ACTION || child_type == GameTreeNode::SHOWDOWN) {// 需要发1张牌 + dfs_node.emplace_back(CHANCE_PLAYER, n_act, parent_act, info | round, parent_dfs_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + // 发牌信息编码,只有1张牌时,占用idx0,有2张牌时,索引较大的占用idx0,较小的占用idx1 + int j = decode_idx0(info), new_info = 0; + for(int i = 0, k = 0; i < n_act; i++, k++) {// 动作索引i,poss_card索引k + if(j == -1) new_info = code_idx0(k);// 第一次发牌 + else {// 第二次发牌,最多发两次牌 + if(k == j) k++;// 两次选的一样,则第二次改成下一个 + new_info = code_idx0(max(j,k)) | code_idx1(min(j,k));// idx0为较大值 + } + dfs(children, i, curr_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx, cnt0, cnt1, new_info); + } + } + else {// CHANCE之后接着CHANCE,再接着SHOWDOWN,需要连续发2张牌 + assert(round == TURN_ROUND); + shared_ptr child = dynamic_pointer_cast(children); + assert(child->getChildren()->getType() == GameTreeNode::SHOWDOWN); + int parent_player = dfs_node[parent_dfs_idx].player;// 父节点玩家 + dfs_node.emplace_back(CHANCE_PLAYER, n_act*(n_act-1)>>1, parent_act, info | round, parent_dfs_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + // float val = node->getPot()/2*2/chance_den[RIVER_ROUND]; + float val = node->getPot()/chance_den[RIVER_ROUND]; + for(int i = 0, j = 0; j < n_act; j++) { + for(int k = j+1; k < n_act; k++) { + ev[SHOWDOWN_TYPE].push_back(val); + leaf_node_dfs[SHOWDOWN_TYPE].push_back(dfs_idx++); + info = code_idx0(k) | code_idx1(j);// idx0为较大值 + dfs_node.emplace_back(CHANCE_PLAYER, 0, i++, info, curr_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + } + } + } + } + else {// river SHOWDOWN, fold + assert(parent_dfs_idx != -1); + int parent_player = dfs_node[parent_dfs_idx].player;// 父节点玩家 + int i = SHOWDOWN_TYPE; + float val = 0; + if(type == GameTreeNode::SHOWDOWN) val = node->getPot()/2; + else {// fold + vector pot = dynamic_pointer_cast(node)->get_payoffs(); + val = pot[P0]; + i = FOLD_TYPE; + } + leaf_node_dfs[i].push_back(curr_idx); + ev[i].push_back(val / chance_den[round]); + dfs_node.emplace_back(parent_player, 0, parent_act, info, parent_dfs_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + } +} + +void SliceCFR::init_poss_card(Deck& deck, size_t board) { + vector &cards = deck.getCards(); + for(Card& card : cards) { + int i = card.getCardInt(); + if(cards_valid(1LL<= 0; r--) chance_branch[r] = poss_card.size() - 4;// 排除2个玩家的手牌,总共4张 + for(int r = init_round+2; r < N_ROUND; r++) chance_branch[r] = chance_branch[r-1] - 1; + print_array(chance_branch, N_ROUND); + for(int r = 0; r <= init_round; r++) chance_den[r] = 1; + for(int r = init_round+1; r < N_ROUND; r++) chance_den[r] = chance_den[r-1] * chance_branch[r]; + print_array(chance_den, N_ROUND); +} + +void SliceCFR::_reach_prob(int player, bool best_cfv) { + vector& offset = slice_offset[player]; + int n = offset.size(), n_hand = hand_size[player]; + node_func func = best_cfv ? reach_prob_avg : reach_prob; + for(int i = 1; i < n; i++) { + #pragma omp parallel for + for(int j = offset[i-1]; j < offset[i]; j++) { + func(player_node_ptr+j, n_hand); + } + } +} +void SliceCFR::_rm(int player, bool best_cfv) { + node_func func = best_cfv ? rm_avg : rm; + int s = slice_offset[player][0], e = slice_offset[player].back(), n_hand = hand_size[player]; + #pragma omp parallel for + for(int i = s; i < e; i++) { + func(player_node_ptr+i, n_hand); + } +} + +void SliceCFR::clear_data(int player) { + int s = slice_offset[player][0], e = slice_offset[player].back(), n_hand = hand_size[player]; + size_t size = 0; + for(int i = s; i < e; i++) { + size = get_size(player_node_ptr[i].n_act, n_hand) * sizeof(float); + memset(player_node_ptr[i].data, 0, size); + } +} + +void SliceCFR::clear_root_cfv() { + size_t size = root_prob.size() * sizeof(float); + memset(root_cfv_ptr[P0], 0, size); +} + +void SliceCFR::train() { + if(!init_succ) return; + size_t start = timeSinceEpochMillisec(), total = 0; + Timer timer; + clear_data(P0); + clear_data(P1); + // _rm(P0, false); + // _rm(P1, false); + // _reach_prob(P0, false); + vector res = exploitability(); + printf("0:%f %f %f\n", res[0], res[1], (res[0]+res[1])/2); + // 计算exploitability后,双方的rm和p0的reach_prob已经恢复 + pos_coef = neg_coef = coef = 0; + double temp = 0; + int cnt = 0, iter = 1; + for(iter = 1; iter <= steps; iter++) { + clear_root_cfv(); + for(int player = P0; player < N_PLAYER; player++) { + step(iter, player, false); + } + temp = pow(iter, alpha); + pos_coef = temp / (temp + 1); + temp = pow(iter, beta); + neg_coef = temp / (temp + 1); + // neg_coef = 0.5; + coef = pow((float)iter/(iter+1), gamma); + if((++cnt) == interval) { + cnt = 0; + res = exploitability(); + total = timeSinceEpochMillisec() - start; + printf("%d:%.3f,%.3fs\n", iter, timer.ms()/1000.0, total/1000.0); + temp = (res[0] + res[1]) / 2; + printf("%d:%f %f %f\n", iter, res[0], res[1], temp); + if(temp <= tol) break; + } + if(stop_flag) break; + } + if(cnt) { + res = exploitability(); + total = timeSinceEpochMillisec() - start; + printf("%d:%.3f,%.3fs\n", iter, timer.ms()/1000.0, total/1000.0); + printf("%d:%f %f %f\n", iter, res[0], res[1], (res[0]+res[1])/2); + } +} + +// player到达概率已经计算好 +void SliceCFR::step(int iter, int player, bool best_cfv) { +#ifdef TIME_LOG + size_t start = timeSinceEpochMillisec(), end = 0; +#endif + int opp = 1 - player, my_hand = hand_size[player]; + _reach_prob(opp, best_cfv); +#ifdef TIME_LOG + end = timeSinceEpochMillisec(); + size_t t1 = end - start; + start = end; +#endif + + leaf_cfv(player); +#ifdef TIME_LOG + end = timeSinceEpochMillisec(); + size_t t2 = end - start; + start = end; +#endif + + vector& offset = slice_offset[player]; + if(!best_cfv) { + #pragma omp parallel for + for(int j = offset[0]; j < offset.back(); j++) { + updata_data(player_node_ptr+j, my_hand, pos_coef, neg_coef, coef); + } + } +#ifdef TIME_LOG + end = timeSinceEpochMillisec(); + size_t t3 = end - start; + start = end; +#endif + + node_func func = best_cfv ? best_cfv_up : cfv_up; + for(int i = offset.size()-1; i > 0; i--) { + #pragma omp parallel for + for(int j = offset[i-1]; j < offset[i]; j++) { + func(player_node_ptr+j, my_hand); + } + } +#ifdef TIME_LOG + end = timeSinceEpochMillisec(); + size_t t4 = end - start; + printf("%zd\t%zd\t%zd\t%zd\n", t1, t2, t3, t4); +#endif +} + +vector SliceCFR::exploitability() { + int opp = 0; + clear_root_cfv(); + for(int player = P0; player < N_PLAYER; player++) { +#ifdef TIME_LOG + size_t start = timeSinceEpochMillisec(); +#endif + opp = 1 - player; + _rm(opp, true);// 改变对方策略 +#ifdef TIME_LOG + size_t t1 = timeSinceEpochMillisec() - start; +#endif + step(0, player, true); +#ifdef TIME_LOG + start = timeSinceEpochMillisec(); +#endif + _rm(opp, false);// 恢复对方策略 +#ifdef TIME_LOG + size_t t2 = timeSinceEpochMillisec() - start; + printf("rm time:%zd\t%zd\n", t1, t2); +#endif + } + _reach_prob(P0, false);// 恢复P0的reach_prob,用于下一次迭代 + int m = 0, n = hand_size[P0]; + float ev0 = 0, ev1 = 0; + for(int i = m; i < n; i++) ev0 += root_cfv[i] * root_prob[i]; + m = n; n = root_prob.size(); + for(int i = m; i < n; i++) ev1 += root_cfv[i] * root_prob[i]; + return {ev0/norm, ev1/norm}; +} + +void SliceCFR::stop() { + stop_flag = true; +} +json SliceCFR::dumps(bool with_status, int depth) { + json ans; + return std::move(ans); +} +vector>> SliceCFR::get_strategy(shared_ptr node, vector cards) { + return {}; +} +vector>> SliceCFR::get_evs(shared_ptr node, vector cards) { + return {}; +} diff --git a/src/tools/CommandLineTool.cpp b/src/tools/CommandLineTool.cpp index 03e8108..a0cd885 100644 --- a/src/tools/CommandLineTool.cpp +++ b/src/tools/CommandLineTool.cpp @@ -92,6 +92,7 @@ void split(const string& s, char c, void CommandLineTool::processCommand(string input) { vector contents; + if(input.empty() || input[0] == '#') return; split(input,' ',contents); if(contents.size() == 0) contents = {input}; if(contents.size() > 2 || contents.size() < 1)throw runtime_error(tfm::format("command not valid: %s",input)); @@ -140,6 +141,8 @@ void CommandLineTool::processCommand(string input) { sizes->push_back(stof(params[i])); } } + }else if(command == "set_raise_limit"){ + this->raise_limit = stoi(paramstr); }else if(command == "set_accuracy"){ this->accuracy = stof(paramstr); }else if(command == "set_allin_threshold"){ @@ -168,13 +171,29 @@ void CommandLineTool::processCommand(string input) { this->accuracy, this->use_isomorphism, 0, // TODO: enable half float option for command line tool - this->thread_number + this->thread_number, + slice_cfr ); }else if(command == "dump_result"){ string output_file = paramstr; this->ps.dump_strategy(QString::fromStdString(output_file),this->dump_rounds); }else if(command == "set_dump_rounds"){ this->dump_rounds = stoi(paramstr); + }else if(command == "estimate_tree_memory"){ + if(range_ip.empty() || range_oop.empty() || board.empty()) { + cout << "Please set range_ip, range_oop and board first." << endl; + return; + } + shared_ptr game_tree = ps.get_game_tree(); + if(game_tree == nullptr) { + cout << "Please buld tree first." << endl; + return; + } + long long size = ps.estimate_tree_memory(range_ip, range_oop, board); + size *= sizeof(float); + cout << (float)size / (1024*1024) << " MB" << endl; + }else if(command == "set_slice_cfr"){ + slice_cfr = stoi(paramstr); }else{ cout << "command not recognized: " << command << endl; } From 690abdb449ab1ffac432cc6d5cd7a21538aece0f Mon Sep 17 00:00:00 2001 From: yffbit Date: Tue, 16 Apr 2024 21:12:42 +0800 Subject: [PATCH 02/19] update TexasSolverGui.pro --- TexasSolverGui.pro | 2 ++ 1 file changed, 2 insertions(+) diff --git a/TexasSolverGui.pro b/TexasSolverGui.pro index 2d6e4d8..6ef0b95 100644 --- a/TexasSolverGui.pro +++ b/TexasSolverGui.pro @@ -85,6 +85,7 @@ SOURCES += \ src/solver/CfrSolver.cpp \ src/solver/PCfrSolver.cpp \ src/solver/Solver.cpp \ + src/solver/slice_cfr.cpp \ src/tools/CommandLineTool.cpp \ src/tools/GameTreeBuildingSettings.cpp \ src/tools/lookup8.cpp \ @@ -137,6 +138,7 @@ HEADERS += \ include/solver/Solver.h \ include/solver/BestResponse.h \ include/solver/CfrSolver.h \ + include/solver/slice_cfr.h \ include/tools/argparse.hpp \ include/tools/CommandLineTool.h \ include/tools/utils.h \ From 9de2914a3edcaa85af8d5d14e6f44f399af7c569 Mon Sep 17 00:00:00 2001 From: yffbit Date: Tue, 16 Apr 2024 21:35:40 +0800 Subject: [PATCH 03/19] fix include error --- include/runtime/PokerSolver.h | 2 +- include/solver/slice_cfr.h | 6 +++--- src/solver/slice_cfr.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/runtime/PokerSolver.h b/include/runtime/PokerSolver.h index b454f65..1a73126 100644 --- a/include/runtime/PokerSolver.h +++ b/include/runtime/PokerSolver.h @@ -12,7 +12,7 @@ #include "include/solver/CfrSolver.h" #include "include/solver/PCfrSolver.h" #include "include/library.h" -#include "solver/slice_cfr.h" +#include "include/solver/slice_cfr.h" #include #include using namespace std; diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h index a79df72..121dd93 100644 --- a/include/solver/slice_cfr.h +++ b/include/solver/slice_cfr.h @@ -5,8 +5,8 @@ #include #include #include -#include "nodes/GameTreeNode.h" -#include "solver/PCfrSolver.h" +#include "include/nodes/GameTreeNode.h" +#include "include/solver/PCfrSolver.h" #include #include @@ -102,7 +102,7 @@ class SliceCFR : public Solver { vector>> get_strategy(shared_ptr node, vector cards); vector>> get_evs(shared_ptr node, vector cards); private: - atomic_bool stop_flag = false; + atomic_bool stop_flag {false}; bool init_succ = false; int n_thread = 0; int thread_per_block = 32; diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp index 138d9f6..2824fbb 100644 --- a/src/solver/slice_cfr.cpp +++ b/src/solver/slice_cfr.cpp @@ -1,5 +1,5 @@ -#include "solver/slice_cfr.h" -#include "ranges/RiverRangeManager.h" +#include "include/solver/slice_cfr.h" +#include "include/ranges/RiverRangeManager.h" // 数组poss_card的索引[0,51]-->[1,52],8位二进制编码,最多选两个,占用高16位,低16位预留其他用途 #define code_idx0(i) (((i)+1)<<24) From adc1bcfe1be98953c68155bf2083bbe0dbbf5239 Mon Sep 17 00:00:00 2001 From: yffbit Date: Tue, 16 Apr 2024 21:56:34 +0800 Subject: [PATCH 04/19] update slice_cfr.cpp --- src/solver/slice_cfr.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp index 2824fbb..1ca367e 100644 --- a/src/solver/slice_cfr.cpp +++ b/src/solver/slice_cfr.cpp @@ -1,3 +1,4 @@ +#include #include "include/solver/slice_cfr.h" #include "include/ranges/RiverRangeManager.h" @@ -428,7 +429,7 @@ size_t SliceCFR::init_player_node() { total += mtx_idx * sizeof(mutex); for(int i : dfs_idx_map) { if(i == -1) continue; - player_node[i].mtx = &mtx[(int)(player_node[i].mtx)]; + player_node[i].mtx = &mtx[(size_t)(player_node[i].mtx)]; } total += n_player_node * sizeof(Node); return total; From ab6ed35840e622a3392d5de4d8b30eca39d15dcd Mon Sep 17 00:00:00 2001 From: yffbit Date: Mon, 22 Apr 2024 19:36:29 +0800 Subject: [PATCH 05/19] atomic_ref --- include/solver/slice_cfr.h | 11 ++++---- src/solver/slice_cfr.cpp | 53 +++++++++++++++++++++----------------- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h index 121dd93..36bc679 100644 --- a/include/solver/slice_cfr.h +++ b/include/solver/slice_cfr.h @@ -46,7 +46,7 @@ struct Node { int n_act = 0;// 动作数 int parent_offset = -1;// 本节点对应的父节点数据reach_prob的偏移量 float *parent_cfv = nullptr; - mutex *mtx = nullptr; + // mutex *mtx = nullptr; float *data = nullptr;// cfv,regret_sum,strategy_sum,reach_prob,sum }; struct LeafNode { @@ -149,7 +149,7 @@ class SliceCFR : public Solver { size_t init_memory(shared_ptr compairer); size_t init_player_node(); size_t init_leaf_node(); - void set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset, mutex *&mtx); + void set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset); void normalization(); size_t init_strength_table(shared_ptr compairer); void dfs(shared_ptr node, int parent_act=-1, int parent_dfs_idx=-1, int parent_p0_act=-1, int parent_p0_idx=-1, int parent_p1_act=-1, int parent_p1_idx=-1, int cnt0=0, int cnt1=0, int info=0); @@ -166,10 +166,9 @@ class SliceCFR : public Solver { vector player_node; Node *player_node_ptr = nullptr; int sd_offset = 0; - vector cpu_cfv; - vector mtx; - vector> mtx_map; - int mtx_idx = N_PLAYER; + // vector mtx; + // vector> mtx_map; + // int mtx_idx = N_PLAYER; vector> strength; size_t _estimate_tree_size(shared_ptr node); void _reach_prob(int player, bool best_cfv=false); diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp index 1ca367e..9e0103f 100644 --- a/src/solver/slice_cfr.cpp +++ b/src/solver/slice_cfr.cpp @@ -2,6 +2,9 @@ #include "include/solver/slice_cfr.h" #include "include/ranges/RiverRangeManager.h" +using std::memory_order_relaxed; +using std::atomic_ref; + // 数组poss_card的索引[0,51]-->[1,52],8位二进制编码,最多选两个,占用高16位,低16位预留其他用途 #define code_idx0(i) (((i)+1)<<24) #define decode_idx0(x) (((x)>>24) - 1) @@ -90,13 +93,14 @@ void best_cfv_up(Node *node, int n_hand) { int size = node->n_act * n_hand; int i = 0, h = 0; float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; - mutex *mtx = node->mtx; + // mutex *mtx = node->mtx; for(h = 0; h < n_hand; h++) { val = cfv[h];// 第一个 for(i = h+n_hand; i < size; i += n_hand) val = max(val, cfv[i]); - mtx->lock(); - parent_cfv[h] += val;// 需要加锁 - mtx->unlock(); + // mtx->lock(); + // parent_cfv[h] += val;// 需要加锁 + // mtx->unlock(); + atomic_ref(parent_cfv[h]).fetch_add(val, memory_order_relaxed); } } // 子节点cfv加权求和 @@ -105,7 +109,7 @@ void cfv_up(Node *node, int n_hand) { int i = 0, h = 0, sum_offset = size << 2; float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; float *regret_sum = cfv + size; - mutex *mtx = node->mtx; + // mutex *mtx = node->mtx; for(h = 0; h < n_hand; h++) { val = 0; if(cfv[sum_offset+h] == 0) { @@ -119,9 +123,10 @@ void cfv_up(Node *node, int n_hand) { val /= cfv[sum_offset+h]; } // cfv[sum_offset+h] = val; - mtx->lock(); - parent_cfv[h] += val;// 需要加锁 - mtx->unlock(); + // mtx->lock(); + // parent_cfv[h] += val;// 需要加锁 + // mtx->unlock(); + atomic_ref(parent_cfv[h]).fetch_add(val, memory_order_relaxed); for(i = h; i < size; i += n_hand) regret_sum[i] += cfv[i] - val;// 更新regret_sum val = 0; for(i = h; i < size; i += n_hand) val += max(0.0f, regret_sum[i]); @@ -375,7 +380,7 @@ SliceCFR::~SliceCFR() { } } -void SliceCFR::set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset, mutex *&mtx) { +void SliceCFR::set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset) { if(player == -1) player = node.player;// 向上连接同玩家节点 int p_idx = node.parent_p0_idx, act_idx = node.parent_p0_act;// 向上连接P0 if(player != P0) {// 向上连接P1 @@ -385,16 +390,16 @@ void SliceCFR::set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &o if(p_idx == -1) { cfv = root_cfv_ptr[player]; offset = root_prob_ptr[player] - root_cfv_ptr[player]; - mtx = (mutex *)player; + // mtx = (mutex *)player; } else { if(player != dfs_node[p_idx].player) throw runtime_error("player mismatch"); cfv = player_node[dfs_idx_map[p_idx]].data + cfv_offset(hand_size[player], act_idx); offset = reach_prob_to_cfv(dfs_node[p_idx].n_act, hand_size[player]); - if(mtx_map[p_idx].empty()) mtx_map[p_idx] = vector(dfs_node[p_idx].n_act, -1); - int &i = mtx_map[p_idx][act_idx]; - if(i == -1) i = mtx_idx++; - mtx = (mutex *)i; + // if(mtx_map[p_idx].empty()) mtx_map[p_idx] = vector(dfs_node[p_idx].n_act, -1); + // int &i = mtx_map[p_idx][act_idx]; + // if(i == -1) i = mtx_idx++; + // mtx = (mutex *)i; } } @@ -404,8 +409,8 @@ size_t SliceCFR::init_player_node() { player_node_ptr = player_node.data(); dfs_idx_map = vector(dfs_idx, -1); slice_offset = vector>(N_PLAYER); - mtx_map = vector>(dfs_idx); - mtx_idx = N_PLAYER; + // mtx_map = vector>(dfs_idx); + // mtx_idx = N_PLAYER; int mem_idx = 0; for(int i = 0; i < N_PLAYER; i++) {// 枚举player for(vector &nodes : slice[i]) {// 枚举slice @@ -414,7 +419,7 @@ size_t SliceCFR::init_player_node() { DFSNode &node = dfs_node[idx]; Node &target = player_node[mem_idx]; target.n_act = node.n_act; - set_cfv_and_offset(node, -1, target.parent_cfv, target.parent_offset, target.mtx); + set_cfv_and_offset(node, -1, target.parent_cfv, target.parent_offset); size = get_size(node.n_act, hand_size[node.player]) * sizeof(float); target.data = (float *)malloc(size); if(target.data == nullptr) throw runtime_error("malloc error"); @@ -424,13 +429,13 @@ size_t SliceCFR::init_player_node() { } slice_offset[i].push_back(mem_idx); } - mtx = vector(mtx_idx); - printf("%d,%d,%d\n", sizeof(mutex), mtx_idx, mtx_idx * sizeof(mutex)); - total += mtx_idx * sizeof(mutex); - for(int i : dfs_idx_map) { - if(i == -1) continue; - player_node[i].mtx = &mtx[(size_t)(player_node[i].mtx)]; - } + // mtx = vector(mtx_idx); + // printf("%d,%d,%d\n", sizeof(mutex), mtx_idx, mtx_idx * sizeof(mutex)); + // total += mtx_idx * sizeof(mutex); + // for(int i : dfs_idx_map) { + // if(i == -1) continue; + // player_node[i].mtx = &mtx[(size_t)(player_node[i].mtx)]; + // } total += n_player_node * sizeof(Node); return total; } From 6f47740ef031a6d5e85289df676a9aa8dce1d312 Mon Sep 17 00:00:00 2001 From: yffbit Date: Mon, 22 Apr 2024 19:42:13 +0800 Subject: [PATCH 06/19] update --- include/solver/slice_cfr.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h index 36bc679..e20a6dc 100644 --- a/include/solver/slice_cfr.h +++ b/include/solver/slice_cfr.h @@ -8,6 +8,7 @@ #include "include/nodes/GameTreeNode.h" #include "include/solver/PCfrSolver.h" #include +#include #include using std::vector; From 34bea8120ff797d64836b6ae20cbb7fc97517dc9 Mon Sep 17 00:00:00 2001 From: yffbit Date: Mon, 22 Apr 2024 19:59:27 +0800 Subject: [PATCH 07/19] fix build error --- TexasSolverGui.pro | 1 + 1 file changed, 1 insertion(+) diff --git a/TexasSolverGui.pro b/TexasSolverGui.pro index 6ef0b95..d93539c 100644 --- a/TexasSolverGui.pro +++ b/TexasSolverGui.pro @@ -25,6 +25,7 @@ DEFINES += QT_DEPRECATED_WARNINGS TRANSLATIONS = lang_cn.ts\ lang_en.ts +CONFIG += c++2a macx: { QMAKE_CXXFLAGS += -Xpreprocessor -fopenmp -lomp -I/usr/local/include From ab7a22fbb97671f01412b1da50f3ccae5f256773 Mon Sep 17 00:00:00 2001 From: yffbit Date: Thu, 25 Apr 2024 22:56:56 +0800 Subject: [PATCH 08/19] dump strategy --- include/solver/slice_cfr.h | 4 +- src/solver/slice_cfr.cpp | 87 +++++++++++++++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h index e20a6dc..19c328c 100644 --- a/include/solver/slice_cfr.h +++ b/include/solver/slice_cfr.h @@ -128,7 +128,7 @@ class SliceCFR : public Solver { int *same_hand_ptr[N_PLAYER] {nullptr,nullptr}; vector> ranges; vector dfs_node; - vector dfs_idx_map;// dfs遍历的每个节点在cuda中的索引 + vector dfs_idx_map;// dfs遍历的每个节点在内存中的索引 int node_cnt[N_TYPE]; int n_leaf_node = 0; int n_player_node = 0; @@ -176,6 +176,8 @@ class SliceCFR : public Solver { void _rm(int player, bool best_cfv=false); void clear_data(int player); void clear_root_cfv(); + json reConvertJson(const shared_ptr& node, int depth, int max_depth, int &idx, int info); + vector> get_avg_strategy(int idx); }; #endif // _SLICE_CFR_H_ diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp index 9e0103f..52c17df 100644 --- a/src/solver/slice_cfr.cpp +++ b/src/solver/slice_cfr.cpp @@ -861,8 +861,10 @@ vector SliceCFR::exploitability() { void SliceCFR::stop() { stop_flag = true; } -json SliceCFR::dumps(bool with_status, int depth) { - json ans; +json SliceCFR::dumps(bool with_status, int depth) {// depth:max_round + int idx = 0; + json ans = reConvertJson(tree->getRoot(), 0, depth, idx, 0); + if(idx != dfs_idx) throw runtime_error("dfs idx error"); return std::move(ans); } vector>> SliceCFR::get_strategy(shared_ptr node, vector cards) { @@ -871,3 +873,84 @@ vector>> SliceCFR::get_strategy(shared_ptr node vector>> SliceCFR::get_evs(shared_ptr node, vector cards) { return {}; } +vector> SliceCFR::get_avg_strategy(int idx) { + Node &node = player_node[dfs_idx_map[idx]]; + int n_hand = hand_size[dfs_node[idx].player], n_act = node.n_act; + int size = n_act * n_hand, i = 0, h = 0, j = 0; + float sum = 0, *strategy_sum = node.data + (size << 1), uni = 1.0 / n_act; + vector> strategy(n_hand, vector(n_act));// [n_hand,n_act] + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += strategy_sum[i]; + if(sum == 0) { + for(j = 0; j < n_act; j++) strategy[h][j] = uni; + } + else { + for(j = 0, i = h; j < n_act; j++, i += n_hand) strategy[h][j] = strategy_sum[i] / sum; + } + } + return strategy; +} +json SliceCFR::reConvertJson(const shared_ptr& node, int depth, int max_depth, int &idx, int info) { + int curr_idx = idx++; + int type = node->getType(), n_act = 0; + json ans; + if(type == GameTreeNode::ACTION) { + shared_ptr one_node = dynamic_pointer_cast(node); + vector actions_str; + if(depth < max_depth) { + int player = one_node->getPlayer(); + for(GameActions one_action : one_node->getActions()) actions_str.push_back(one_action.toString()); + ans["actions"] = actions_str; + ans["player"] = player; + ans["node_type"] = "action_node"; + + vector> strategy = get_avg_strategy(curr_idx); + ans["strategy"] = json(); + ans["strategy"]["actions"] = actions_str; + json stt; + size_t n_hand = hand_size[player]; + int *ptr = hand_card_ptr[player]; + for(size_t i = 0; i < n_hand; i++) { + stt[Card::intCard2Str(ptr[i+n_hand])+Card::intCard2Str(ptr[i])] = strategy[i]; + } + ans["strategy"]["strategy"] = std::move(stt); + + ans["childrens"] = json(); + } + vector> children = one_node->getChildrens(); + n_act = children.size(); + for(int i = 0; i < n_act; i++) { + json child = reConvertJson(children[i], depth, max_depth, idx, info); + if(depth < max_depth) ans["childrens"][actions_str[i]] = child; + } + } + else if(type == GameTreeNode::CHANCE) { + if((++depth) <= max_depth) ans["node_type"] = "chance_node"; + shared_ptr chance_node = dynamic_pointer_cast(node); + shared_ptr children = chance_node->getChildren();// 不为null + int child_type = children->getType(); + n_act = chance_branch[GameTreeNode::gameRound2int(node->getRound())] + 4; + if(child_type == GameTreeNode::ACTION || child_type == GameTreeNode::SHOWDOWN) {// 需要发1张牌 + if(depth <= max_depth) ans["deal_number"] = n_act; + if(depth < max_depth) ans["dealcards"] = json();// 需要展开子节点 + int j = decode_idx0(info), new_info = 0; + for(int i = 0, k = 0; i < n_act; i++, k++) {// 动作索引i,poss_card索引k + if(j == -1) new_info = code_idx0(k);// 第一次发牌 + else {// 第二次发牌,最多发两次牌 + if(k == j) k++;// 两次选的一样,则第二次改成下一个 + // new_info = code_idx0(max(j,k)) | code_idx1(min(j,k));// idx0为较大值 + } + json child = reConvertJson(children, depth, max_depth, idx, new_info); + if(depth < max_depth) ans["dealcards"][Card::intCard2Str(poss_card[k])] = child; + } + } + else { + n_act = n_act*(n_act-1)>>1; + idx += n_act; + if(depth <= max_depth) ans["deal_number"] = n_act; + } + } + // else {} + return std::move(ans); +} \ No newline at end of file From cbbdc4d8615b6dce68be7e343c91ae40b4e13e34 Mon Sep 17 00:00:00 2001 From: yffbit Date: Wed, 15 May 2024 22:57:00 +0800 Subject: [PATCH 09/19] fix min_card --- src/solver/slice_cfr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp index 52c17df..b5a7344 100644 --- a/src/solver/slice_cfr.cpp +++ b/src/solver/slice_cfr.cpp @@ -912,7 +912,7 @@ json SliceCFR::reConvertJson(const shared_ptr& node, int depth, in size_t n_hand = hand_size[player]; int *ptr = hand_card_ptr[player]; for(size_t i = 0; i < n_hand; i++) { - stt[Card::intCard2Str(ptr[i+n_hand])+Card::intCard2Str(ptr[i])] = strategy[i]; + stt[Card::intCard2Str(ptr[i+n_hand]+min_card)+Card::intCard2Str(ptr[i]+min_card)] = strategy[i]; } ans["strategy"]["strategy"] = std::move(stt); From ed84501e69d8b2b82dd62d711941d43e4675963a Mon Sep 17 00:00:00 2001 From: yffbit Date: Mon, 20 May 2024 23:08:04 +0800 Subject: [PATCH 10/19] cuda --- CMakeLists.txt | 13 +- include/ranges/RiverRangeManager.h | 4 + include/runtime/PokerSolver.h | 3 +- include/solver/cuda_cfr.h | 82 +++++++ include/solver/cuda_func.h | 23 ++ include/solver/slice_cfr.h | 33 +-- include/tools/CommandLineTool.h | 2 +- include/tools/utils.h | 14 +- src/runtime/PokerSolver.cpp | 56 +++-- src/solver/cuda_cfr.cu | 346 +++++++++++++++++++++++++++++ src/solver/cuda_func.cu | 277 +++++++++++++++++++++++ src/solver/slice_cfr.cpp | 23 +- 12 files changed, 816 insertions(+), 60 deletions(-) create mode 100644 include/solver/cuda_cfr.h create mode 100644 include/solver/cuda_func.h create mode 100644 src/solver/cuda_cfr.cu create mode 100644 src/solver/cuda_func.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 11f5d36..e072b9f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.20) -# project(TexasSolver LANGUAGES CXX CUDA) -project(TexasSolver LANGUAGES CXX) +project(TexasSolver LANGUAGES CXX CUDA) +# project(TexasSolver LANGUAGES CXX) set(CMAKE_CXX_STANDARD 20) # set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -10,7 +10,7 @@ set(CMAKE_CXX_STANDARD 20) set(CMAKE_AUTORCC ON) set(CMAKE_AUTOUIC ON) -set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD 20) # set(CMAKE_CUDA_STANDARD_REQUIRED ON) message("${CMAKE_MINOR_VERSION}") if(${CMAKE_MINOR_VERSION} GREATER_EQUAL 24) @@ -23,6 +23,7 @@ else() endif() message("CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}") +message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}") if((DEFINED CMAKE_BUILD_TYPE) AND (CMAKE_BUILD_TYPE STREQUAL Debug)) set(CMAKE_CUDA_FLAGS "-g -G ${CMAKE_CUDA_FLAGS}") endif() @@ -50,7 +51,7 @@ list(REMOVE_ITEM SRC ${GUI_SRC} ${EXE_SRC} ${API_SRC}) # message("GUI_SRC=${GUI_SRC}") # message("API_SRC=${API_SRC}") # message("EXE_SRC=${EXE_SRC}") -# file(GLOB_RECURSE CUDA_SRC src/*.cu) +file(GLOB_RECURSE CUDA_SRC src/*.cu) message("CUDA_SRC=${CUDA_SRC}") set(BASE_LIB TexasSolver) @@ -65,7 +66,9 @@ target_link_libraries(${API_TARGET} PUBLIC ${BASE_LIB}) set(EXE console_solver) add_executable(${EXE} ${EXE_SRC}) target_link_libraries(${EXE} PRIVATE ${BASE_LIB}) -# target_link_options(${EXE} PUBLIC "/NODEFAULTLIB:LIBCMT") +if(MSVC) + target_link_options(${EXE} PUBLIC "/NODEFAULTLIB:LIBCMT") +endif() file(GLOB FORMS *.ui) file(GLOB RESOURCES *.qrc) diff --git a/include/ranges/RiverRangeManager.h b/include/ranges/RiverRangeManager.h index cf57d4f..7d29eee 100644 --- a/include/ranges/RiverRangeManager.h +++ b/include/ranges/RiverRangeManager.h @@ -18,6 +18,10 @@ class RiverRangeManager { RiverRangeManager(shared_ptr handEvaluator); const vector& getRiverCombos(int player, const vector& riverCombos, const vector& board); const vector& getRiverCombos(int player, const vector& riverCombos, uint64_t board_long); + void clear() { + p1RiverRanges.clear(); + p2RiverRanges.clear(); + } private: unordered_map> p1RiverRanges; unordered_map> p2RiverRanges; diff --git a/include/runtime/PokerSolver.h b/include/runtime/PokerSolver.h index 1a73126..c1c97ed 100644 --- a/include/runtime/PokerSolver.h +++ b/include/runtime/PokerSolver.h @@ -13,6 +13,7 @@ #include "include/solver/PCfrSolver.h" #include "include/library.h" #include "include/solver/slice_cfr.h" +#include "include/solver/cuda_cfr.h" #include #include using namespace std; @@ -46,7 +47,7 @@ class PokerSolver { bool use_isomorphism, int use_halffloats, int threads, - bool slice_cfr = false + int slice_cfr = 0 ); void stop(); long long estimate_tree_memory(string& p1_range, string& p2_range, string& board); diff --git a/include/solver/cuda_cfr.h b/include/solver/cuda_cfr.h new file mode 100644 index 0000000..90f07fe --- /dev/null +++ b/include/solver/cuda_cfr.h @@ -0,0 +1,82 @@ +#ifndef _CUDA_CFR_H_ +#define _CUDA_CFR_H_ + +#include +#include +#include +#include +#include "nodes/GameTreeNode.h" +#include "solver/PCfrSolver.h" +#include +#include +#include "cuda_runtime.h" +#include "solver/slice_cfr.h" + +#define LANE_SIZE 32 + +struct CudaLeafNode { + float val = 0;// fold:player0的收益*随机概率,sd:胜者收益*随机概率 + int offset_prob_sum = 0; + int offset_p0 = 0; + int offset_p1 = 0; + float *data_p0 = nullptr; + float *data_p1 = nullptr; + int *info = nullptr; +}; +struct SDNode { + float val = 0;// 胜者收益*随机概率 + int offset_prob_sum = 0; + int offset_p0 = 0; + int offset_p1 = 0; + float *data_p0 = nullptr; + float *data_p1 = nullptr; + int *strength_data = nullptr; +}; + +class CudaCFR : public SliceCFR { +public: + CudaCFR( + shared_ptr tree, + vector &range1, + vector &range2, + vector &initial_board, + shared_ptr compairer, + Deck &deck, + int train_step, + int print_interval, + float accuracy, + int n_thread + ):SliceCFR(tree, range1, range2, initial_board, compairer, deck, train_step, print_interval, accuracy, n_thread) {} + virtual ~CudaCFR(); + virtual size_t estimate_tree_size(); +protected: + int *dev_hand_card = nullptr; + int *dev_hand_card_ptr[N_PLAYER] {nullptr,nullptr}; + size_t *dev_hand_hash = nullptr; + size_t *dev_hand_hash_ptr[N_PLAYER] {nullptr,nullptr}; + int *dev_same_hand_idx = nullptr; + Node *dev_nodes = nullptr;// cuda内存地址 + CudaLeafNode *dev_leaf_node = nullptr;// cuda内存地址 + vector dev_data; + vector dev_strength; + float *dev_root_cfv = nullptr, *dev_prob_sum = nullptr; + virtual size_t init_memory(); + size_t init_player_node(); + size_t init_leaf_node(); + void set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset); + size_t init_strength_table(); + virtual void step(int iter, int player, bool best_cfv=false); + virtual void leaf_cfv(int player); + int block_size(int size) {// ceil + return (size + LANE_SIZE - 1) / LANE_SIZE; + } + void clear_prob_sum(int len); + virtual void _reach_prob(int player, bool best_cfv=false); + virtual void _rm(int player, bool best_cfv=false); + virtual void clear_data(int player); + virtual void clear_root_cfv(); + virtual void post_process(); + virtual vector> get_avg_strategy(int idx); +}; + +#endif // _CUDA_CFR_H_ diff --git a/include/solver/cuda_func.h b/include/solver/cuda_func.h new file mode 100644 index 0000000..1c7cb4c --- /dev/null +++ b/include/solver/cuda_func.h @@ -0,0 +1,23 @@ +#ifndef _CUDA_FUNC_H_ +#define _CUDA_FUNC_H_ + +#include "cuda_runtime.h" + +extern __host__ __device__ void print_data(int *arr, int n); +extern __host__ __device__ void print_data(size_t *arr, int n); +extern __host__ __device__ void print_data(float *arr, int n); +extern __global__ void print_data_kernel(int *arr, int n); +extern __global__ void print_data_kernel(size_t *arr, int n); +extern __global__ void print_data_kernel(float *arr, int n); +extern __global__ void clear_data_kernel(Node *node, int size, int n_hand); +extern __global__ void rm_avg_kernel(Node *node, int size, int n_hand); +extern __global__ void rm_kernel(Node *node, int size, int n_hand); +extern __global__ void reach_prob_avg_kernel(Node *node, int size, int n_hand); +extern __global__ void reach_prob_kernel(Node *node, int size, int n_hand); +extern __global__ void fold_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *hand_card, size_t *hand_hash, int *same_hand_idx); +extern __global__ void sd_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *my_card, int *opp_card, int n_card); +extern __global__ void best_cfv_kernel(Node *node, int size, int n_hand); +extern __global__ void cfv_kernel(Node *node, int size, int n_hand); +extern __global__ void updata_data_kernel(Node *node, int size, int n_hand, float pos_coef, float neg_coef, float coef); + +#endif // _CUDA_FUNC_H_ \ No newline at end of file diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h index 19c328c..eadf8f7 100644 --- a/include/solver/slice_cfr.h +++ b/include/solver/slice_cfr.h @@ -43,6 +43,12 @@ using std::mutex; #define reach_prob_offset(n_act, n_hand, act_idx) (((n_act) * 3 + (act_idx)) * (n_hand)) #define reach_prob_to_cfv(n_act, n_hand) ((n_act) * (n_hand) * 3) +// 数组poss_card的索引[0,51]-->[1,52],8位二进制编码,最多选两个,占用高16位,低16位预留其他用途 +#define code_idx0(i) (((i)+1)<<24) +#define decode_idx0(x) (((x)>>24) - 1) +#define code_idx1(i) (((i)+1)<<16) +#define decode_idx1(x) ((((x)>>16)&0xff) - 1) + struct Node { int n_act = 0;// 动作数 int parent_offset = -1;// 本节点对应的父节点数据reach_prob的偏移量 @@ -94,19 +100,18 @@ class SliceCFR : public Solver { float accuracy, int n_thread ); - ~SliceCFR(); - size_t estimate_tree_size(); + virtual ~SliceCFR(); + virtual size_t estimate_tree_size(); void train(); vector exploitability(); void stop(); json dumps(bool with_status, int depth); vector>> get_strategy(shared_ptr node, vector cards); vector>> get_evs(shared_ptr node, vector cards); -private: +protected: atomic_bool stop_flag {false}; bool init_succ = false; int n_thread = 0; - int thread_per_block = 32; int steps = 0, interval = 0, n_card = N_CARD, min_card = 0; int init_round = 0; int dfs_idx = 0;// 先序遍历 @@ -143,20 +148,21 @@ class SliceCFR : public Solver { float *root_cfv_ptr[N_PLAYER] {nullptr,nullptr}; shared_ptr tree = nullptr; Deck& deck; + void init(); void init_hand_card(vector &range1, vector &range2); void init_hand_card(vector &range, vector &cards, vector &prob, size_t board, vector &out); void init_same_hand_idx(); void init_min_card(); - size_t init_memory(shared_ptr compairer); + virtual size_t init_memory(); size_t init_player_node(); size_t init_leaf_node(); void set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset); void normalization(); - size_t init_strength_table(shared_ptr compairer); + size_t init_strength_table(); void dfs(shared_ptr node, int parent_act=-1, int parent_dfs_idx=-1, int parent_p0_act=-1, int parent_p0_idx=-1, int parent_p1_act=-1, int parent_p1_idx=-1, int cnt0=0, int cnt1=0, int info=0); void init_poss_card(Deck& deck, size_t board); - void step(int iter, int player, bool best_cfv=false); - void leaf_cfv(int player); + virtual void step(int iter, int player, bool best_cfv=false); + virtual void leaf_cfv(int player); void fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, size_t board); void sd_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, int idx); void append_node_idx(int p_idx, int act_idx, int player, int cpu_node_idx); @@ -172,12 +178,13 @@ class SliceCFR : public Solver { // int mtx_idx = N_PLAYER; vector> strength; size_t _estimate_tree_size(shared_ptr node); - void _reach_prob(int player, bool best_cfv=false); - void _rm(int player, bool best_cfv=false); - void clear_data(int player); - void clear_root_cfv(); + virtual void _reach_prob(int player, bool best_cfv=false); + virtual void _rm(int player, bool best_cfv=false); + virtual void clear_data(int player); + virtual void clear_root_cfv(); + virtual void post_process() {} json reConvertJson(const shared_ptr& node, int depth, int max_depth, int &idx, int info); - vector> get_avg_strategy(int idx); + virtual vector> get_avg_strategy(int idx); }; #endif // _SLICE_CFR_H_ diff --git a/include/tools/CommandLineTool.h b/include/tools/CommandLineTool.h index a6482a5..d319c36 100644 --- a/include/tools/CommandLineTool.h +++ b/include/tools/CommandLineTool.h @@ -41,7 +41,7 @@ class CommandLineTool{ int max_iteration=100; int use_isomorphism=0; int print_interval=10; - bool slice_cfr = false; + int slice_cfr = 0; int dump_rounds = 1; shared_ptr gtbs; }; diff --git a/include/tools/utils.h b/include/tools/utils.h index 6009be1..6b5b70b 100644 --- a/include/tools/utils.h +++ b/include/tools/utils.h @@ -73,11 +73,17 @@ class Timer { void reset() { start = std::chrono::steady_clock::now(); } - int64_t ms() { - return std::chrono::duration_cast(std::chrono::steady_clock::now()-start).count(); + int64_t ms(bool reset=false) { + std::chrono::steady_clock::time_point curr = std::chrono::steady_clock::now(); + int64_t ans = std::chrono::duration_cast(curr-start).count(); + if(reset) start = curr; + return ans; } - int64_t us() { - return std::chrono::duration_cast(std::chrono::steady_clock::now()-start).count(); + int64_t us(bool reset=false) { + std::chrono::steady_clock::time_point curr = std::chrono::steady_clock::now(); + int64_t ans = std::chrono::duration_cast(curr-start).count(); + if(reset) start = curr; + return ans; } private: std::chrono::steady_clock::time_point start {}; diff --git a/src/runtime/PokerSolver.cpp b/src/runtime/PokerSolver.cpp index 1cf551f..07966f9 100644 --- a/src/runtime/PokerSolver.cpp +++ b/src/runtime/PokerSolver.cpp @@ -86,7 +86,7 @@ long long PokerSolver::estimate_tree_memory(string &p1_range, string &p2_range, } void PokerSolver::train(string p1_range, string p2_range, string boards, string log_file, int iteration_number, - int print_interval, string algorithm,int warmup,float accuracy,bool use_isomorphism, int use_halffloats, int threads, bool slice_cfr) { + int print_interval, string algorithm,int warmup,float accuracy,bool use_isomorphism, int use_halffloats, int threads, int slice_cfr) { string player1RangeStr = p1_range; string player2RangeStr = p2_range; @@ -105,31 +105,39 @@ void PokerSolver::train(string p1_range, string p2_range, string boards, string string logfile_name = log_file; if(solver) solver.reset();// 释放内存 - if(slice_cfr) { - solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads); + try { + if(slice_cfr == 1) { + solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads); + } + else if(slice_cfr == 2) { + solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads); + } + else { + solver = make_shared( + game_tree + , range1 + , range2 + , initialBoard + , compairer + , deck + , iteration_number + , false + , print_interval + , logfile_name + , algorithm + , Solver::MonteCarolAlg::NONE + , warmup + , accuracy + , use_isomorphism + , use_halffloats + , threads + ); + } + solver->train(); } - else { - solver = make_shared( - game_tree - , range1 - , range2 - , initialBoard - , compairer - , deck - , iteration_number - , false - , print_interval - , logfile_name - , algorithm - , Solver::MonteCarolAlg::NONE - , warmup - , accuracy - , use_isomorphism - , use_halffloats - , threads - ); + catch(std::exception& e) { + std::cerr << e.what() << '\n'; } - solver->train(); } void PokerSolver::dump_strategy(QString dump_file,int dump_rounds) { diff --git a/src/solver/cuda_cfr.cu b/src/solver/cuda_cfr.cu new file mode 100644 index 0000000..52bfc79 --- /dev/null +++ b/src/solver/cuda_cfr.cu @@ -0,0 +1,346 @@ +#include "solver/cuda_cfr.h" +#include "solver/cuda_func.h" +#include "ranges/RiverRangeManager.h" + +void cuda_error(cudaError_t error, const char *file, int line) { + if(error != cudaSuccess) { + printf("%s in %s at line %d\n", cudaGetErrorString(error), file, line); + exit(EXIT_FAILURE); + } +} + +#define CHECK_ERROR(error) (cuda_error(error, __FILE__, __LINE__)) + +template +void copy_to_device(T *dev, T *host, int n, bool print=false) { + if(!dev || !host || n <= 0) return; + size_t size = n * sizeof(T); + CHECK_ERROR(cudaMemcpy(dev, host, size, cudaMemcpyHostToDevice)); + if(!print) return; + print_data(host, n); + print_data_kernel<<<1, 1>>>(dev, n); + cudaDeviceSynchronize(); +} + +int max_malloc_len(int left, int right, int group_size = 1) { + int mid = 0, size = group_size * sizeof(float); + float *p = nullptr; + while(left < right) { + mid = (left + right + 1) >> 1;// 靠右 + if(cudaMalloc(&p, mid * size) == cudaSuccess) { + cudaFree(p); + left = mid; + } + else right = mid - 1; + } + return left; +} + +void CudaCFR::leaf_cfv(int player) { + Timer timer; + int opp = 1 - player, offset = player == P0 ? 0 : hand_size[P0]; + int my_hand = hand_size[player], opp_hand = hand_size[opp]; + int size = node_cnt[FOLD_TYPE]; + int block = block_size(size); + clear_prob_sum(size); + fold_cfv_kernel<<>>( + player, size, dev_leaf_node, dev_prob_sum, my_hand, opp_hand, + dev_hand_card_ptr[opp], dev_hand_hash_ptr[opp], dev_same_hand_idx+offset + ); + cudaDeviceSynchronize(); + // printf("fold_cfv:%zd ms\n", timer.ms(true)); + + size = node_cnt[SHOWDOWN_TYPE]; + block = block_size(size); + clear_prob_sum(size); + sd_cfv_kernel<<>>( + player, size, dev_leaf_node+sd_offset, dev_prob_sum, my_hand, opp_hand, + dev_hand_card_ptr[player], dev_hand_card_ptr[opp], n_card + ); + cudaDeviceSynchronize(); + // printf("sd_cfv:%zd ms\n", timer.ms()); +} + +CudaCFR::~CudaCFR() { + if(dev_root_cfv) CHECK_ERROR(cudaFree(dev_root_cfv)); + if(dev_hand_card) CHECK_ERROR(cudaFree(dev_hand_card)); + if(dev_hand_hash) CHECK_ERROR(cudaFree(dev_hand_hash)); + if(dev_nodes) CHECK_ERROR(cudaFree(dev_nodes)); + if(dev_leaf_node) CHECK_ERROR(cudaFree(dev_leaf_node)); + for(float *p : dev_data) { + if(p) CHECK_ERROR(cudaFree(p)); + } + if(dev_prob_sum) CHECK_ERROR(cudaFree(dev_prob_sum)); + for(int *p : dev_strength) { + if(p) CHECK_ERROR(cudaFree(p)); + } +} + +void CudaCFR::set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset) { + if(player == -1) player = node.player;// 向上连接同玩家节点 + int p_idx = node.parent_p0_idx, act_idx = node.parent_p0_act;// 向上连接P0 + if(player != P0) {// 向上连接P1 + p_idx = node.parent_p1_idx; + act_idx = node.parent_p1_act; + } + if(p_idx == -1) { + cfv = root_cfv_ptr[player]; + offset = root_prob_ptr[player] - root_cfv_ptr[player]; + } + else { + if(player != dfs_node[p_idx].player) throw runtime_error("player mismatch"); + cfv = dev_data[dfs_idx_map[p_idx]] + cfv_offset(hand_size[player], act_idx); + offset = reach_prob_to_cfv(dfs_node[p_idx].n_act, hand_size[player]); + } +} + +size_t CudaCFR::init_player_node() { + size_t total = 0, size = 0, node_size = n_player_node * sizeof(Node); + vector cpu_node(n_player_node);// 与cuda内存对应 + CHECK_ERROR(cudaMalloc(&dev_nodes, node_size)); + total += node_size; + dev_data = vector(n_player_node, nullptr); + dfs_idx_map = vector(dfs_idx, -1); + slice_offset = vector>(N_PLAYER); + int mem_idx = 0; + for(int i = 0; i < N_PLAYER; i++) {// 枚举player + for(vector &nodes : slice[i]) {// 枚举slice + slice_offset[i].push_back(mem_idx); + for(int idx : nodes) {// 枚举node + DFSNode &node = dfs_node[idx]; + Node &target = cpu_node[mem_idx];// cpu存储位置 + target.n_act = node.n_act; + set_cfv_and_offset(node, -1, target.parent_cfv, target.parent_offset); + size = get_size(node.n_act, hand_size[node.player]) * sizeof(float); + CHECK_ERROR(cudaMalloc(&target.data, size)); + if(target.data == nullptr) throw runtime_error("malloc error"); + total += size; + dev_data[mem_idx] = target.data; + dfs_idx_map[idx] = mem_idx++; + } + } + slice_offset[i].push_back(mem_idx); + } + CHECK_ERROR(cudaMemcpy(dev_nodes, cpu_node.data(), node_size, cudaMemcpyHostToDevice)); + return total; +} + +size_t CudaCFR::init_leaf_node() { + size_t node_size = n_leaf_node * sizeof(CudaLeafNode); + vector cpu_node(n_leaf_node);// 与cuda内存对应 + CHECK_ERROR(cudaMalloc(&dev_leaf_node, node_size)); + int mem_idx = 0; + for(int t = 0; t < N_LEAF_TYPE; t++) { + for(int i = 0; i < leaf_node_dfs[t].size(); i++) { + DFSNode &node = dfs_node[leaf_node_dfs[t][i]]; + CudaLeafNode &target = cpu_node[mem_idx++];// cpu存储位置 + target.val = ev[t][i]; + target.offset_prob_sum = i * n_card; + set_cfv_and_offset(node, P0, target.data_p0, target.offset_p0); + set_cfv_and_offset(node, P1, target.data_p1, target.offset_p1); + int j = decode_idx0(node.info), k = decode_idx1(node.info); + size_t info = init_board; + if(t == FOLD_TYPE) { + if(j != -1) info |= 1LL << poss_card[j]; + if(k != -1) info |= 1LL << poss_card[k]; + target.info = (int *)info; + } + else { + if(j == -1) info = 0; + else if(k == -1) info = j; + else info = tril_idx(j, k); + target.info = dev_strength[info]; + } + } + } + CHECK_ERROR(cudaMemcpy(dev_leaf_node, cpu_node.data(), node_size, cudaMemcpyHostToDevice)); + sd_offset = leaf_node_dfs[FOLD_TYPE].size(); + ev.clear(); + return node_size; +} + +size_t CudaCFR::init_memory() { + size_t total = 0; + int n = root_prob.size(); + root_cfv = vector(n, 0); + size_t size = ((n << 1)) * sizeof(float);// cfv + prob + CHECK_ERROR(cudaMalloc(&dev_root_cfv, size)); + total += size; + root_cfv_ptr[P0] = dev_root_cfv; + root_cfv_ptr[P1] = dev_root_cfv + hand_size[P0]; + root_prob_ptr[P0] = root_cfv_ptr[P0] + n; + root_prob_ptr[P1] = root_prob_ptr[P0] + hand_size[P0]; + clear_root_cfv(); + copy_to_device(root_prob_ptr[P0], root_prob.data(), n); + + vector temp_hand_card = hand_card; + vector temp_hand_hash = hand_hash; + // [P0,P1,P0] + temp_hand_card.insert(temp_hand_card.end(), hand_card.begin(), hand_card.begin()+(hand_size[P0]<<1)); + temp_hand_hash.insert(temp_hand_hash.end(), hand_hash.begin(), hand_hash.begin()+hand_size[P0]); + n = temp_hand_card.size(); + size = (n + same_hand_idx.size()) * sizeof(int);// [P0,P1,P0] + [P0,P1] + CHECK_ERROR(cudaMalloc(&dev_hand_card, size)); + total += size; + copy_to_device(dev_hand_card, temp_hand_card.data(), n); + dev_same_hand_idx = dev_hand_card + n; + copy_to_device(dev_same_hand_idx, same_hand_idx.data(), same_hand_idx.size()); + + n = temp_hand_hash.size(); + size = n * sizeof(size_t); + CHECK_ERROR(cudaMalloc(&dev_hand_hash, size)); + total += size; + copy_to_device(dev_hand_hash, temp_hand_hash.data(), n); + dev_hand_card_ptr[P0] = dev_hand_card; + dev_hand_card_ptr[P1] = dev_hand_card + (hand_size[P0]<<1); + dev_hand_hash_ptr[P0] = dev_hand_hash; + dev_hand_hash_ptr[P1] = dev_hand_hash + hand_size[P0]; + + total += init_player_node(); + total += init_strength_table(); + total += init_leaf_node(); + + // FOLD_TYPE,SHOWDOWN_TYPE,共用dev_prob_sum + int len = max(node_cnt[FOLD_TYPE], node_cnt[SHOWDOWN_TYPE]); + size = len * n_card * sizeof(float); + CHECK_ERROR(cudaMalloc(&dev_prob_sum, size)); + total += size; + return total; +} + +size_t CudaCFR::init_strength_table() { + SliceCFR::init_strength_table(); + int n = strength.size(); + size_t total = 0, size = 0; + dev_strength = vector(n, nullptr); + for(int i = 0; i < n; i++) { + const RiverCombs *p0_comb = strength[i][P0].data, *p1_comb = strength[i][P1].data; + int p0_size = strength[i][P0].size, p1_size = strength[i][P1].size, d = 0; + vector data(2+((p0_size+p1_size)<<1)); + data[d++] = 2 + (p0_size<<1); + data[d++] = data.size(); + for(int j = 0; j < p0_size; j++) { + data[d++] = p0_comb[j].rank; + data[d++] = p0_comb[j].reach_prob_index; + } + for(int j = 0; j < p1_size; j++) { + data[d++] = p1_comb[j].rank; + data[d++] = p1_comb[j].reach_prob_index; + } + size = data.size() * sizeof(int); + CHECK_ERROR(cudaMalloc(&dev_strength[i], size)); + total += size; + copy_to_device(dev_strength[i], data.data(), data.size()); + } + strength.clear(); + rrm.clear(); + return total; +} + +size_t CudaCFR::estimate_tree_size() { + for(int i = 0; i < N_TYPE; i++) node_cnt[i] = 0; + if(tree == nullptr) return 0; + size_t size = _estimate_tree_size(tree->getRoot()); + n_leaf_node = node_cnt[FOLD_TYPE] + node_cnt[SHOWDOWN_TYPE]; + n_player_node = node_cnt[N_LEAF_TYPE+P0] + node_cnt[N_LEAF_TYPE+P1]; + size *= sizeof(float); + size += n_leaf_node * sizeof(CudaLeafNode); + size += n_player_node * sizeof(Node); + size += max(node_cnt[FOLD_TYPE], node_cnt[SHOWDOWN_TYPE]) * n_card * sizeof(float); + return size; +} + +void CudaCFR::_reach_prob(int player, bool best_cfv) { + vector& offset = slice_offset[player]; + int n = offset.size() - 1, size = 0, block = 0, n_hand = hand_size[player]; + for(int i = 0; i < n; i++) { + size = offset[i+1] - offset[i]; + block = block_size(size); + if(best_cfv) reach_prob_avg_kernel<<>>(dev_nodes+offset[i], size, n_hand); + else reach_prob_kernel<<>>(dev_nodes+offset[i], size, n_hand); + cudaDeviceSynchronize(); + } +} + +void CudaCFR::_rm(int player, bool best_cfv) { + int size = node_cnt[N_LEAF_TYPE + player]; + int block = block_size(size); + Node *node = dev_nodes + slice_offset[player][0]; + if(best_cfv) rm_avg_kernel<<>>(node, size, hand_size[player]); + else rm_kernel<<>>(node, size, hand_size[player]); + cudaDeviceSynchronize(); +} + +void CudaCFR::clear_data(int player) { + int size = node_cnt[N_LEAF_TYPE + player]; + int block = block_size(size); + clear_data_kernel<<>>(dev_nodes+slice_offset[player][0], size, hand_size[player]); + cudaDeviceSynchronize(); +} + +void CudaCFR::clear_prob_sum(int len) { + CHECK_ERROR(cudaMemset(dev_prob_sum, 0, len * n_card * sizeof(float))); + cudaDeviceSynchronize(); +} + +void CudaCFR::clear_root_cfv() { + CHECK_ERROR(cudaMemset(dev_root_cfv, 0, root_cfv.size() * sizeof(float))); + cudaDeviceSynchronize(); +} + +void CudaCFR::step(int iter, int player, bool best_cfv) { + Timer timer; + int opp = 1 - player, my_hand = hand_size[player], size = 0, block = 0; + _reach_prob(opp, best_cfv); + size_t t1 = timer.ms(true); + + leaf_cfv(player); + size_t t2 = timer.ms(true); + + if(!best_cfv) { + size = n_player_node; + block = block_size(size); + updata_data_kernel<<>>(dev_nodes, size, my_hand, pos_coef, neg_coef, coef); + cudaDeviceSynchronize(); + } + size_t t3 = timer.ms(true); + vector& offset = slice_offset[player]; + for(int i = offset.size()-2; i >= 0; i--) { + size = offset[i+1] - offset[i]; + block = block_size(size); + if(best_cfv) best_cfv_kernel<<>>(dev_nodes+offset[i], size, my_hand); + else cfv_kernel<<>>(dev_nodes+offset[i], size, my_hand); + cudaDeviceSynchronize(); + } + size_t t4 = timer.ms(); + printf("%zd\t%zd\t%zd\t%zd\n", t1, t2, t3, t4); +} + +void CudaCFR::post_process() { + int n = root_cfv.size(); + CHECK_ERROR(cudaMemcpy(root_cfv.data(), dev_root_cfv, n * sizeof(float), cudaMemcpyDeviceToHost)); + // print_data(root_cfv.data(), n); + // print_data_kernel<<<1, 1>>>(dev_root_cfv, n); + // cudaDeviceSynchronize(); +} + +vector> CudaCFR::get_avg_strategy(int idx) { + DFSNode &node = dfs_node[idx]; + int n_hand = hand_size[node.player], n_act = node.n_act; + int size = n_act * n_hand, i = 0, h = 0, j = 0; + float *dev = dev_data[dfs_idx_map[idx]] + (size << 1), sum = 0, uni = 1.0 / n_act; + vector strategy_sum(size);// [n_act,n_hand] + CHECK_ERROR(cudaMemcpy(strategy_sum.data(), dev, size * sizeof(float), cudaMemcpyDeviceToHost)); + vector> strategy(n_hand, vector(n_act));// [n_hand,n_act] + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += strategy_sum[i]; + if(sum == 0) { + for(j = 0; j < n_act; j++) strategy[h][j] = uni; + } + else { + for(j = 0, i = h; j < n_act; j++, i += n_hand) strategy[h][j] = strategy_sum[i] / sum; + } + } + return strategy; +} \ No newline at end of file diff --git a/src/solver/cuda_func.cu b/src/solver/cuda_func.cu new file mode 100644 index 0000000..0c02b6e --- /dev/null +++ b/src/solver/cuda_func.cu @@ -0,0 +1,277 @@ +#include "solver/cuda_cfr.h" +#include "solver/cuda_func.h" +#include "device_launch_parameters.h" + +__host__ __device__ void print_data(int *arr, int n) { + if(arr != nullptr && n > 0) { + printf("%d", arr[0]); + for(int i = 1; i < n; i++) printf(",%d", arr[i]); + } + printf("\n"); +} +__host__ __device__ void print_data(size_t *arr, int n) { + if(arr != nullptr && n > 0) { + printf("%llx", arr[0]); + for(int i = 1; i < n; i++) printf(",%llx", arr[i]); + } + printf("\n"); +} +__host__ __device__ void print_data(float *arr, int n) { + if(arr != nullptr && n > 0) { + printf("%.2f", arr[0]); + for(int i = 1; i < n; i++) printf(",%.2f", arr[i]); + } + printf("\n"); +} +__global__ void print_data_kernel(int *arr, int n) { + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i == 0) print_data(arr, n); +} +__global__ void print_data_kernel(size_t *arr, int n) { + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i == 0) print_data(arr, n); +} +__global__ void print_data_kernel(float *arr, int n) { + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i == 0) print_data(arr, n); +} + +__global__ void clear_data_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = get_size(node->n_act, n_hand); + float *data = node->data; + for(i = 0; i < size; i++) data[i] = 0; +} + +// 不同节点之间独立 +__global__ void rm_avg_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = node->n_act * n_hand; + int h = 0, sum_offset = size << 1; + float *data = node->data + (size << 1);// strategy_sum + float sum = 0; + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += data[i]; + data[sum_offset+h] = sum; + } +} +__global__ void rm_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = node->n_act * n_hand; + int h = 0, sum_offset = size * 3; + float *data = node->data + size;// regret_sum + float sum = 0; + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += max(0.0f, data[i]); + data[sum_offset+h] = sum; + } +} + +// 上层slice传递到本层slice +__global__ void reach_prob_avg_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + int n_act = node->n_act; + size = n_act * n_hand; + int h = 0, sum_offset = size << 1; + float *data = node->data + (size << 1);// strategy_sum + float *parent_prob = node->parent_cfv + node->parent_offset, temp = 0; + for(h = 0; h < n_hand; h++) { + if(data[sum_offset+h] == 0) {// 1/n_act + temp = parent_prob[h] / n_act; + for(i = h; i < size; i += n_hand) data[size+i] = temp; + } + else { + temp = parent_prob[h] / data[sum_offset+h]; + for(i = h; i < size; i += n_hand) data[size+i] = temp * data[i]; + } + } +} +__global__ void reach_prob_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + int n_act = node->n_act; + size = n_act * n_hand; + int h = 0, rp_offset = size << 1, sum_offset = rp_offset + size; + float *data = node->data + size;// regret_sum + float *parent_prob = node->parent_cfv + node->parent_offset, temp = 0; + for(h = 0; h < n_hand; h++) { + if(data[sum_offset+h] == 0) {// 1/n_act + temp = parent_prob[h] / n_act; + for(i = h; i < size; i += n_hand) data[rp_offset+i] = temp; + } + else { + temp = parent_prob[h] / data[sum_offset+h]; + for(i = h; i < size; i += n_hand) data[rp_offset+i] = temp * max(0.0f, data[i]); + } + } +} + +// 叶子节点向上层slice聚合,调用前需要清零上层slice的cfv +// same_hand_idx:player same_hand_idx +// hand_hash,hand_card:init opp [P0,P1,P0] +__global__ void fold_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *hand_card, size_t *hand_hash, int *same_hand_idx) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + opp_prob_sum += node->offset_prob_sum; + size_t board = (size_t)node->info; + float *cfv = nullptr, *opp_reach = nullptr, val = node->val; + float prob_sum = 0, temp = 0; + if(player == P0) { + cfv = node->data_p0, opp_reach = node->data_p1 + node->offset_p1; + } + else { + cfv = node->data_p1, opp_reach = node->data_p0 + node->offset_p0; + val = -val; + } + for(i = 0; i < opp_hand; i++) { + if(hand_hash[i] & board) continue;// 对方手牌与公共牌冲突 + temp = opp_reach[i]; + opp_prob_sum[hand_card[i]] += temp;// card1 + opp_prob_sum[hand_card[i+opp_hand]] += temp;// card2 + prob_sum += temp; + } + hand_hash += opp_hand;// ptr of player + hand_card += (opp_hand << 1); + for(i = 0; i < my_hand; i++) { + if(hand_hash[i] & board) { + // cfv[i] = 0;// 与公共牌冲突,cfv为0 + continue; + } + temp = same_hand_idx[i] != -1 ? opp_reach[same_hand_idx[i]] : 0;// 重复计算的部分 + temp = (prob_sum - opp_prob_sum[hand_card[i]] - opp_prob_sum[hand_card[i+my_hand]] + temp) * val; + atomicAdd(cfv+i, temp); + } +} + +// showdown +__global__ void sd_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *my_card, int *opp_card, int n_card) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return;// 总任务数 + node += i; + opp_prob_sum += node->offset_prob_sum; + float *cfv = nullptr, *opp_reach = nullptr; + float prob_sum = 0, temp = 0; + int j = 0, size_j = 0, h = 0, s = 0, *strength_data = node->info; + // strength_data:2+size0,2+size0+size1,sorted_data + // i,size for player + // j,size_j for opp + if(player == P0) { + i = 2, size_j = strength_data[1]; + size = j = strength_data[0]; + cfv = node->data_p0, opp_reach = node->data_p1 + node->offset_p1; + } + else { + j = 2, size = strength_data[1]; + size_j = i = strength_data[0]; + cfv = node->data_p1, opp_reach = node->data_p0 + node->offset_p0; + } + // strength_data += 2; + for(; i < size; i += 2) {// strength值变小,己方手牌变强 + s = strength_data[i]; + for(; j < size_j && strength_data[j] > s; j += 2) {// (胜过对方条件下)找到对方的最强手牌 + h = strength_data[j+1]; + temp = opp_reach[h]; + opp_prob_sum[opp_card[h]] += temp;// card1 + opp_prob_sum[opp_card[h+opp_hand]] += temp;// card2 + prob_sum += temp; + } + h = strength_data[i+1]; + temp = (prob_sum - opp_prob_sum[my_card[h]] - opp_prob_sum[my_card[h+my_hand]]) * node->val; + atomicAdd(cfv+h, temp); + } + prob_sum = 0; + for(h = 0; h < n_card; h++) opp_prob_sum[h] = 0; + i -= 2, j -= 2; + if(player == P0) { + size_j = size; + size = 2; + } + else { + size = size_j; + size_j = 2; + } + for(; i >= size; i -= 2) {// strength值变大,己方手牌变弱 + s = strength_data[i]; + for(; j >= size_j && strength_data[j] < s; j -= 2) {// (败给对方条件下)找到对方的最弱手牌 + h = strength_data[j+1]; + temp = opp_reach[h]; + opp_prob_sum[opp_card[h]] += temp;// card1 + opp_prob_sum[opp_card[h+opp_hand]] += temp;// card2 + prob_sum += temp; + } + h = strength_data[i+1]; + temp = (opp_prob_sum[my_card[h]] + opp_prob_sum[my_card[h+my_hand]] - prob_sum) * node->val; + atomicAdd(cfv+h, temp); + } +} + +// 本层slice向上层slice聚合,上层cfv需要先清零 +// 子节点cfv中选最大值 +__global__ void best_cfv_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = node->n_act * n_hand; + int h = 0; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + for(h = 0; h < n_hand; h++) { + val = cfv[h];// 第一个 + for(i = h+n_hand; i < size; i += n_hand) val = max(val, cfv[i]); + atomicAdd(parent_cfv+h, val); + } + for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} +// 子节点cfv加权求和 +__global__ void cfv_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + int n_act = node->n_act; + size = n_act * n_hand; + int h = 0, sum_offset = size << 2; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + float *regret_sum = cfv + size; + for(h = 0; h < n_hand; h++) { + val = 0; + if(cfv[sum_offset+h] == 0) { + for(i = h; i < size; i += n_hand) val += cfv[i]; + val /= n_act;// uniform strategy + } + else { + for(i = h; i < size; i += n_hand) { + val += cfv[i] * max(0.0f, regret_sum[i]); + } + val /= cfv[sum_offset+h]; + } + atomicAdd(parent_cfv+h, val); + for(i = h; i < size; i += n_hand) regret_sum[i] += cfv[i] - val;// 更新regret_sum + val = 0; + for(i = h; i < size; i += n_hand) val += max(0.0f, regret_sum[i]); + cfv[sum_offset+h] = val;// 求和 + } + for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} + +__global__ void updata_data_kernel(Node *node, int size, int n_hand, float pos_coef, float neg_coef, float coef) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = node->n_act * n_hand; + float *regret_sum = node->data + size, *strategy_sum = regret_sum + size; + for(i = 0; i < size; i++) { + regret_sum[i] *= regret_sum[i] > 0 ? pos_coef : neg_coef; + strategy_sum[i] = strategy_sum[i] * coef + strategy_sum[size+i]; + } +} diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp index b5a7344..0f2aa9e 100644 --- a/src/solver/slice_cfr.cpp +++ b/src/solver/slice_cfr.cpp @@ -5,12 +5,6 @@ using std::memory_order_relaxed; using std::atomic_ref; -// 数组poss_card的索引[0,51]-->[1,52],8位二进制编码,最多选两个,占用高16位,低16位预留其他用途 -#define code_idx0(i) (((i)+1)<<24) -#define decode_idx0(x) (((x)>>24) - 1) -#define code_idx1(i) (((i)+1)<<16) -#define decode_idx1(x) ((((x)>>16)&0xff) - 1) - void print_array(int *arr, int n) { if(arr != nullptr && n > 0) { printf("%d", arr[0]); @@ -102,6 +96,7 @@ void best_cfv_up(Node *node, int n_hand) { // mtx->unlock(); atomic_ref(parent_cfv[h]).fetch_add(val, memory_order_relaxed); } + for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv } // 子节点cfv加权求和 void cfv_up(Node *node, int n_hand) { @@ -310,7 +305,7 @@ size_t SliceCFR::init_leaf_node() { printf("%zd,%zd,%zd,%zd\n", min_val[P0], max_val[P0], min_val[P1], max_val[P1]); ev[FOLD_TYPE].insert(ev[FOLD_TYPE].end(), ev[SHOWDOWN_TYPE].begin(), ev[SHOWDOWN_TYPE].end()); - ev[FOLD_TYPE].clear(); + ev[SHOWDOWN_TYPE].clear(); ev_ptr = ev[FOLD_TYPE].data(); size_t total = n_leaf_node * sizeof(LeafNode); total += (pre_leaf_node[P0].size() + pre_leaf_node[P1].size()) * sizeof(PreLeafNode); @@ -343,7 +338,9 @@ SliceCFR::SliceCFR( if(this->n_thread == 0) this->n_thread = omp_get_num_procs(); omp_set_num_threads(this->n_thread); // test_parallel_for(this->n_thread); - +} + +void SliceCFR::init() { float unit = 1 << 20; size_t size = estimate_tree_size(); printf("estimate memory:%f MB\n", size/unit); @@ -369,7 +366,7 @@ SliceCFR::SliceCFR( assert(node_cnt[N_LEAF_TYPE+CHANCE_PLAYER] == chance_node.size()); if(dfs_idx == 0 || dfs_node[0].n_act == 0) return; - size = init_memory(compairer); + size = init_memory(); printf("%d nodes, total:%f MB\n", dfs_idx, size/unit); init_succ = true; } @@ -440,7 +437,7 @@ size_t SliceCFR::init_player_node() { return total; } -size_t SliceCFR::init_memory(shared_ptr compairer) { +size_t SliceCFR::init_memory() { size_t total = 0; int n = root_prob.size(); root_cfv = vector(n<<1, 0); @@ -453,11 +450,11 @@ size_t SliceCFR::init_memory(shared_ptr compairer) { total += init_player_node(); total += init_leaf_node(); - total += init_strength_table(compairer); + total += init_strength_table(); return total; } -size_t SliceCFR::init_strength_table(shared_ptr compairer) { +size_t SliceCFR::init_strength_table() { int n = poss_card.size(); vector board_hash; if(init_round == RIVER_ROUND) board_hash.push_back(init_board); @@ -736,6 +733,7 @@ void SliceCFR::clear_root_cfv() { } void SliceCFR::train() { + init(); if(!init_succ) return; size_t start = timeSinceEpochMillisec(), total = 0; Timer timer; @@ -849,6 +847,7 @@ vector SliceCFR::exploitability() { printf("rm time:%zd\t%zd\n", t1, t2); #endif } + post_process(); _reach_prob(P0, false);// 恢复P0的reach_prob,用于下一次迭代 int m = 0, n = hand_size[P0]; float ev0 = 0, ev1 = 0; From 06bea956813be94de0ca862d8888a919d1ddd49d Mon Sep 17 00:00:00 2001 From: yffbit Date: Wed, 29 May 2024 11:44:42 +0800 Subject: [PATCH 11/19] decoupling between console and qt --- .github/workflows/main.yml | 4 +- CMakeLists.txt | 169 ++++++------ TexasSolverGui.pro | 4 + benchmark/texassolver.txt | 13 +- boardselector.cpp | 6 +- boardselector.h | 4 +- include/Card.h | 14 +- include/card_format.h | 11 + include/library.h | 2 +- include/runtime/PokerSolver.h | 30 ++- include/runtime/qsolverjob.h | 19 +- include/solver/BestResponse.h | 3 +- include/solver/PCfrSolver.h | 2 +- include/solver/Solver.h | 5 +- include/solver/cuda_cfr.h | 11 +- include/solver/slice_cfr.h | 5 +- include/tools/CommandLineTool.h | 83 ++++-- include/tools/GameTreeBuildingSettings.h | 3 +- include/tools/StreetSetting.h | 3 +- include/tools/logger.h | 43 ++++ include/ui/tablestrategymodel.h | 1 + mainwindow.cpp | 155 +++++++++-- mainwindow.h | 38 +++ rangeselector.cpp | 6 +- rangeselector.h | 4 +- src/Card.cpp | 12 +- src/api.cpp | 15 +- src/card_format.cpp | 32 +++ src/compairer/Dic5Compairer.cpp | 33 ++- src/console.cpp | 18 +- src/library.cpp | 2 +- src/nodes/GameActions.cpp | 4 +- src/runtime/PokerSolver.cpp | 65 +++-- src/runtime/qsolverjob.cpp | 100 +++++--- src/solver/BestResponse.cpp | 15 +- src/solver/PCfrSolver.cpp | 32 ++- src/solver/Solver.cpp | 2 +- src/solver/slice_cfr.cpp | 25 +- src/tools/CommandLineTool.cpp | 312 +++++++++++++++-------- src/tools/GameTreeBuildingSettings.cpp | 2 +- src/tools/logger.cpp | 47 ++++ src/ui/boardselectortabledelegate.cpp | 3 +- src/ui/detailitemdelegate.cpp | 18 +- src/ui/strategyitemdelegate.cpp | 2 +- src/ui/treemodel.cpp | 4 +- strategyexplorer.cpp | 10 +- 46 files changed, 969 insertions(+), 422 deletions(-) create mode 100644 include/card_format.h create mode 100644 include/tools/logger.h create mode 100644 src/card_format.cpp create mode 100644 src/tools/logger.cpp diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9a56d84..d552122 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,7 +38,9 @@ jobs: - name: make appimage run: | ls - ./build-AppImage.sh + # ./build-AppImage.sh + cmake -DCMAKE_BUILD_TYPE=Release -S . -B build + make -C build -j - uses: actions/upload-artifact@v2 with: diff --git a/CMakeLists.txt b/CMakeLists.txt index e072b9f..351234f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,31 +1,20 @@ cmake_minimum_required(VERSION 3.20) -project(TexasSolver LANGUAGES CXX CUDA) -# project(TexasSolver LANGUAGES CXX) +option(USE_CUDA "" OFF) +option(QT_GUI "" ON) +option(BUILD_API "" ON) -set(CMAKE_CXX_STANDARD 20) -# set(CMAKE_CXX_STANDARD_REQUIRED ON) - -# set(CMAKE_AUTOMOC ON) -set(CMAKE_AUTORCC ON) -set(CMAKE_AUTOUIC ON) - -set(CMAKE_CUDA_STANDARD 20) -# set(CMAKE_CUDA_STANDARD_REQUIRED ON) -message("${CMAKE_MINOR_VERSION}") -if(${CMAKE_MINOR_VERSION} GREATER_EQUAL 24) - # set(CMAKE_CUDA_ARCHITECTURES all) - set(CMAKE_CUDA_ARCHITECTURES all-major) - # set(CMAKE_CUDA_ARCHITECTURES native) +if(USE_CUDA) + project(TexasSolver LANGUAGES CXX CUDA) else() - set(CMAKE_CUDA_ARCHITECTURES OFF) - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=all-major") + project(TexasSolver LANGUAGES CXX) endif() -message("CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}") -message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}") -if((DEFINED CMAKE_BUILD_TYPE) AND (CMAKE_BUILD_TYPE STREQUAL Debug)) - set(CMAKE_CUDA_FLAGS "-g -G ${CMAKE_CUDA_FLAGS}") +set(CMAKE_CXX_STANDARD 20) +# set(CMAKE_CXX_STANDARD_REQUIRED ON) +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") + message("CMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}") endif() set(CMAKE_INCLUDE_CURRENT_DIR ON) @@ -34,34 +23,57 @@ include_directories(include) find_package(OpenMP REQUIRED) message("OpenMP_CXX_FLAGS=${OpenMP_CXX_FLAGS}") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler ${OpenMP_CXX_FLAGS}") -message("CMAKE_CUDA_FLAGS=${CMAKE_CUDA_FLAGS}") - -find_package(QT NAMES Qt6 Qt5 REQUIRED COMPONENTS Widgets) -set(QT_MAJOR Qt${QT_VERSION_MAJOR}) -message("QT_MAJOR=${QT_MAJOR}") -find_package(${QT_MAJOR} REQUIRED COMPONENTS Core Widgets LinguistTools) - file(GLOB_RECURSE SRC src/*.cpp) +file(GLOB_RECURSE EXC_SRC src/*format.cpp) file(GLOB GUI_SRC *.cpp src/ui/*.cpp src/runtime/qsolverjob.cpp) file(GLOB API_SRC src/api.cpp) file(GLOB EXE_SRC src/console.cpp) -list(REMOVE_ITEM SRC ${GUI_SRC} ${EXE_SRC} ${API_SRC}) +list(REMOVE_ITEM SRC ${EXC_SRC} ${GUI_SRC} ${EXE_SRC} ${API_SRC}) # message("SRC=${SRC}") +# message("EXC_SRC=${EXC_SRC}") # message("GUI_SRC=${GUI_SRC}") # message("API_SRC=${API_SRC}") # message("EXE_SRC=${EXE_SRC}") -file(GLOB_RECURSE CUDA_SRC src/*.cu) -message("CUDA_SRC=${CUDA_SRC}") + +if(USE_CUDA) + add_definitions(-DUSE_CUDA) + file(GLOB_RECURSE CUDA_SRC src/*.cu) + message("CUDA_SRC=${CUDA_SRC}") + + set(CMAKE_CUDA_STANDARD 20) + # set(CMAKE_CUDA_STANDARD_REQUIRED ON) + message("CMAKE_MINOR_VERSION=${CMAKE_MINOR_VERSION}") + if(${CMAKE_MINOR_VERSION} GREATER_EQUAL 24) + # set(CMAKE_CUDA_ARCHITECTURES all) + set(CMAKE_CUDA_ARCHITECTURES all-major) + # set(CMAKE_CUDA_ARCHITECTURES native) + else() + set(CMAKE_CUDA_ARCHITECTURES OFF) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=all-major") + endif() + message("CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}") + + message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}") + if((DEFINED CMAKE_BUILD_TYPE) AND (CMAKE_BUILD_TYPE STREQUAL Debug)) + set(CMAKE_CUDA_FLAGS "-g -G ${CMAKE_CUDA_FLAGS}") + endif() + + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler ${OpenMP_CXX_FLAGS}") + message("CMAKE_CUDA_FLAGS=${CMAKE_CUDA_FLAGS}") +endif() set(BASE_LIB TexasSolver) add_library(${BASE_LIB} ${SRC} ${CUDA_SRC}) -target_link_libraries(${BASE_LIB} PUBLIC ${QT_MAJOR}::Core OpenMP::OpenMP_CXX) +target_link_libraries(${BASE_LIB} PUBLIC OpenMP::OpenMP_CXX) +if(USE_CUDA) # set_target_properties(${BASE_LIB} PROPERTIES CUDA_SEPARABLE_COMPILATION ON) +endif() -set(API_TARGET api) -add_library(${API_TARGET} SHARED ${API_SRC}) -target_link_libraries(${API_TARGET} PUBLIC ${BASE_LIB}) +if(BUILD_API) + set(API_TARGET api) + add_library(${API_TARGET} SHARED ${API_SRC}) + target_link_libraries(${API_TARGET} PUBLIC ${BASE_LIB}) +endif() set(EXE console_solver) add_executable(${EXE} ${EXE_SRC}) @@ -70,40 +82,51 @@ if(MSVC) target_link_options(${EXE} PUBLIC "/NODEFAULTLIB:LIBCMT") endif() -file(GLOB FORMS *.ui) -file(GLOB RESOURCES *.qrc) -file(GLOB TS_FILES *.ts) -file(GLOB QM_FILES *.qm) -# message("FORMS=${FORMS}") -# message("RESOURCES=${RESOURCES}") -# message("TS_FILES=${TS_FILES}") -# message("QM_FILES=${QM_FILES}") - -SET(ICON_NAME texassolver_logo) -if(WIN32) - file(GLOB ICON_FILE imgs/${ICON_NAME}.rc) -elseif(APPLE) - set(MACOSX_BUNDLE_ICON_FILE ${ICON_NAME}.icns) - file(GLOB ICON_FILE imgs/${ICON_NAME}.icns) - set_source_files_properties(${ICON_FILE} PROPERTIES MACOSX_PACKAGE_LOCATION "Resources") +if(QT_GUI) + file(GLOB FORMS *.ui) + file(GLOB RESOURCES *.qrc) + file(GLOB TS_FILES *.ts) + file(GLOB QM_FILES *.qm) + # message("FORMS=${FORMS}") + # message("RESOURCES=${RESOURCES}") + # message("TS_FILES=${TS_FILES}") + # message("QM_FILES=${QM_FILES}") + + # set(CMAKE_AUTOMOC ON) + set(CMAKE_AUTORCC ON) + set(CMAKE_AUTOUIC ON) + + find_package(QT NAMES Qt6 Qt5 REQUIRED COMPONENTS Widgets) + set(QT_MAJOR Qt${QT_VERSION_MAJOR}) + message("QT_MAJOR=${QT_MAJOR}") + find_package(${QT_MAJOR} REQUIRED COMPONENTS Core Widgets LinguistTools) + + SET(ICON_NAME texassolver_logo) + if(WIN32) + file(GLOB ICON_FILE imgs/${ICON_NAME}.rc) + elseif(APPLE) + set(MACOSX_BUNDLE_ICON_FILE ${ICON_NAME}.icns) + file(GLOB ICON_FILE imgs/${ICON_NAME}.icns) + set_source_files_properties(${ICON_FILE} PROPERTIES MACOSX_PACKAGE_LOCATION "Resources") + endif() + # message("ICON_FILE=${ICON_FILE}") + + # set(CMAKE_AUTOMOC ON) doesn't work + # Q_OBJECT header + file(GLOB HEADERS *.h include/ui/*.h include/runtime/qsolverjob.h) + # message("HEADERS=${HEADERS}") + if(${QT_VERSION_MAJOR} GREATER_EQUAL 6) + qt6_wrap_cpp(GUI_SRC ${HEADERS}) + else() + qt5_wrap_cpp(GUI_SRC ${HEADERS}) + endif() + # message("GUI_SRC=${GUI_SRC}") + + set(GUI TexasSolverGui) + add_executable(${GUI} ${GUI_SRC} ${EXC_SRC} ${RESOURCES} ${FORMS} ${ICON_FILE}) + target_link_libraries(${GUI} PRIVATE ${QT_MAJOR}::Widgets ${QT_MAJOR}::Core ${BASE_LIB}) + set_target_properties(${GUI} PROPERTIES + WIN32_EXECUTABLE ON + MACOSX_BUNDLE ON + ) endif() -# message("ICON_FILE=${ICON_FILE}") - -# set(CMAKE_AUTOMOC ON) doesn't work -# Q_OBJECT header -file(GLOB HEADERS *.h include/ui/*.h include/runtime/qsolverjob.h) -# message("HEADERS=${HEADERS}") -if(${QT_VERSION_MAJOR} GREATER_EQUAL 6) - qt6_wrap_cpp(GUI_SRC ${HEADERS}) -else() - qt5_wrap_cpp(GUI_SRC ${HEADERS}) -endif() -# message("GUI_SRC=${GUI_SRC}") - -set(GUI TexasSolverGui) -add_executable(${GUI} ${GUI_SRC} ${RESOURCES} ${FORMS} ${ICON_FILE}) -target_link_libraries(${GUI} PRIVATE ${QT_MAJOR}::Widgets ${BASE_LIB}) -set_target_properties(${GUI} PROPERTIES - WIN32_EXECUTABLE ON - MACOSX_BUNDLE ON -) diff --git a/TexasSolverGui.pro b/TexasSolverGui.pro index d93539c..1ca9f44 100644 --- a/TexasSolverGui.pro +++ b/TexasSolverGui.pro @@ -65,6 +65,7 @@ SOURCES += \ mainwindow.cpp \ src/Deck.cpp \ src/Card.cpp \ + src/card_format.cpp \ # src/console.cpp \ src/GameTree.cpp \ src/library.cpp \ @@ -95,6 +96,7 @@ SOURCES += \ src/tools/Rule.cpp \ src/tools/StreetSetting.cpp \ src/tools/utils.cpp \ + src/tools/logger.cpp \ src/trainable/CfrPlusTrainable.cpp \ src/trainable/DiscountedCfrTrainable.cpp \ src/trainable/DiscountedCfrTrainableHF.cpp \ @@ -131,6 +133,7 @@ HEADERS += \ include/trainable/DiscountedCfrTrainableSF.h \ mainwindow.h \ include/Card.h \ + include/card_format.h \ include/GameTree.h \ include/Deck.h \ include/json.hpp \ @@ -167,6 +170,7 @@ HEADERS += \ include/ranges/RiverCombs.h \ include/ranges/RiverRangeManager.h \ include/tools/tinyformat.h \ + include/tools/logger.h \ include/tools/qdebugstream.h \ include/runtime/qsolverjob.h \ qstextedit.h \ diff --git a/benchmark/texassolver.txt b/benchmark/texassolver.txt index 9d37c5f..748dec1 100644 --- a/benchmark/texassolver.txt +++ b/benchmark/texassolver.txt @@ -1,9 +1,9 @@ set_pot 10 set_effective_stack 95 -#set_board Qs,Jh,2h,4d +set_board Qs,Jh,2h,4d #set_range_oop AA,KK,QQ,JJ #set_range_ip QQ:0.5,JJ:0.75 -set_board Qs,Jh,2h +#set_board Qs,Jh,2h set_range_oop AA,KK,QQ,JJ,TT,99:0.75,88:0.75,77:0.5,66:0.25,55:0.25,AK,AQs,AQo:0.75,AJs,AJo:0.5,ATs:0.75,A6s:0.25,A5s:0.75,A4s:0.75,A3s:0.5,A2s:0.5,KQs,KQo:0.5,KJs,KTs:0.75,K5s:0.25,K4s:0.25,QJs:0.75,QTs:0.75,Q9s:0.5,JTs:0.75,J9s:0.75,J8s:0.75,T9s:0.75,T8s:0.75,T7s:0.75,98s:0.75,97s:0.75,96s:0.5,87s:0.75,86s:0.5,85s:0.5,76s:0.75,75s:0.5,65s:0.75,64s:0.5,54s:0.75,53s:0.5,43s:0.5 set_range_ip QQ:0.5,JJ:0.75,TT,99,88,77,66,55,44,33,22,AKo:0.25,AQs,AQo:0.75,AJs,AJo:0.75,ATs,ATo:0.75,A9s,A8s,A7s,A6s,A5s,A4s,A3s,A2s,KQ,KJ,KTs,KTo:0.5,K9s,K8s,K7s,K6s,K5s,K4s:0.5,K3s:0.5,K2s:0.5,QJ,QTs,Q9s,Q8s,Q7s,JTs,JTo:0.5,J9s,J8s,T9s,T8s,T7s,98s,97s,96s,87s,86s,76s,75s,65s,64s,54s,53s,43s set_bet_sizes oop,flop,bet,100 @@ -31,11 +31,12 @@ build_tree estimate_tree_memory set_thread_num 6 #set_thread_num 81920 -set_slice_cfr 1 +set_slice_cfr 0 set_accuracy 0.3 -set_max_iteration 2000 +set_max_iteration 1 set_print_interval 10 -#set_use_isomorphism 1 +set_use_isomorphism 1 start_solve set_dump_rounds 1 -#dump_result output_result.json +dump_result output_result2.json +#dump_setting output_setting.txt diff --git a/boardselector.cpp b/boardselector.cpp index 6250eb9..82f1db3 100644 --- a/boardselector.cpp +++ b/boardselector.cpp @@ -1,7 +1,7 @@ #include "boardselector.h" #include "ui_boardselector.h" -boardselector::boardselector(QTextEdit* boardEdit,QSolverJob::Mode mode,QWidget *parent) : +boardselector::boardselector(QTextEdit* boardEdit,PokerMode mode,QWidget *parent) : QDialog(parent), ui(new Ui::boardselector) { @@ -11,9 +11,9 @@ boardselector::boardselector(QTextEdit* boardEdit,QSolverJob::Mode mode,QWidget this->mode = mode; QString ranks; - if(mode == QSolverJob::Mode::HOLDEM){ + if(mode == PokerMode::HOLDEM){ ranks = "A,K,Q,J,T,9,8,7,6,5,4,3,2"; - }else if(mode == QSolverJob::Mode::SHORTDECK){ + }else if(mode == PokerMode::SHORTDECK){ ranks = "A,K,Q,J,T,9,8,7,6"; }else{ throw runtime_error("mode not found in range selector"); diff --git a/boardselector.h b/boardselector.h index 99d9a5c..bb1ee8d 100644 --- a/boardselector.h +++ b/boardselector.h @@ -18,7 +18,7 @@ class boardselector : public QDialog Q_OBJECT public: - explicit boardselector(QTextEdit* boardEdit,QSolverJob::Mode mode = QSolverJob::Mode::HOLDEM,QWidget *parent = 0); + explicit boardselector(QTextEdit* boardEdit,PokerMode mode = PokerMode::HOLDEM,QWidget *parent = 0); ~boardselector(); private slots: @@ -37,7 +37,7 @@ private slots: private: Ui::boardselector *ui; QTextEdit* boardEdit = NULL; - QSolverJob::Mode mode; + PokerMode mode; QStringList rank_list; BoardSelectorTableModel * boardSelectorTableModel = NULL; BoardSelectorTableDelegate * boardSelectorTableDelegate = NULL; diff --git a/include/Card.h b/include/Card.h index 7b52585..a64cd8f 100644 --- a/include/Card.h +++ b/include/Card.h @@ -8,7 +8,7 @@ #include #include #include "include/tools/tinyformat.h" -#include +// #include using namespace std; class Card { @@ -20,17 +20,17 @@ class Card { Card(); explicit Card(string card,int card_number_in_deck); Card(string card); - string getCard(); + const string& getCard(); int getCardInt(); bool empty(); int getNumberInDeckInt(); static int card2int(Card card); - static int strCard2int(string card); + static int strCard2int(const string &card); static string intCard2Str(int card); static uint64_t boardCards2long(vector cards); static uint64_t boardCard2long(Card& card); static uint64_t boardCards2long(vector& cards); - static QString boardCards2html(vector& cards); + // static QString boardCards2html(vector& cards); static inline bool boardsHasIntercept(uint64_t board1,uint64_t board2){ return ((board1 & board2) != 0); }; @@ -43,9 +43,9 @@ class Card { static int rankToInt(char rank); static int suitToInt(char suit); static vector getSuits(); - string toString(); - string toFormattedString(); - QString toFormattedHtml(); + // string toString(); + // string toFormattedString(); + // QString toFormattedHtml(); }; #endif //TEXASSOLVER_CARD_H diff --git a/include/card_format.h b/include/card_format.h new file mode 100644 index 0000000..83aed98 --- /dev/null +++ b/include/card_format.h @@ -0,0 +1,11 @@ +#if !defined(_CARD_FORMAT_H_) +#define _CARD_FORMAT_H_ + +#include +#include "include/Card.h" + +string toFormattedString(Card &card); +QString toFormattedHtml(Card &card); +QString boardCards2html(vector& cards); + +#endif // _CARD_FORMAT_H_ diff --git a/include/library.h b/include/library.h index cb958f8..59f3d7b 100644 --- a/include/library.h +++ b/include/library.h @@ -78,7 +78,7 @@ Combinations::comb(unsigned long long n, unsigned long long k) { return r; } -vector string_split(string strin,char split); +vector string_split(string &strin, char split); uint64_t timeSinceEpochMillisec(); int random(int min, int max); float normalization_tanh(float stack,float ev,float ratio=7); diff --git a/include/runtime/PokerSolver.h b/include/runtime/PokerSolver.h index c1c97ed..f67487f 100644 --- a/include/runtime/PokerSolver.h +++ b/include/runtime/PokerSolver.h @@ -13,15 +13,21 @@ #include "include/solver/PCfrSolver.h" #include "include/library.h" #include "include/solver/slice_cfr.h" -#include "include/solver/cuda_cfr.h" -#include -#include +// #include +// #include using namespace std; +enum PokerMode { + HOLDEM, + SHORTDECK, + UNKNOWN +}; + class PokerSolver { public: - PokerSolver(); - PokerSolver(string ranks,string suits,string compairer_file,int compairer_file_lines,string compairer_file_bin); + PokerSolver() {} + PokerSolver(PokerMode mode, string &resource_dir); + PokerSolver(string &ranks, string &suits, string &compairer_file, int compairer_file_lines, string &compairer_file_bin); void load_game_tree(string game_tree_file); void build_game_tree( float oop_commit, @@ -35,13 +41,13 @@ class PokerSolver { float allin_threshold ); void train( - string p1_range, - string p2_range, - string boards, - string log_file, + string &p1_range, + string &p2_range, + string &boards, + // string &log_file, int iteration_number, int print_interval, - string algorithm, + string &algorithm, int warmup, float accuracy, bool use_isomorphism, @@ -53,11 +59,13 @@ class PokerSolver { long long estimate_tree_memory(string& p1_range, string& p2_range, string& board); vector player1Range; vector player2Range; - void dump_strategy(QString dump_file,int dump_rounds); + void dump_strategy(string &dump_file, int dump_rounds); shared_ptr get_game_tree(){return this->game_tree;}; Deck* get_deck(){return &this->deck;} shared_ptr get_solver(){return this->solver;} + Logger *logger = nullptr; private: + void init(string &ranks, string &suits, string &compairer_file, int compairer_file_lines, string &compairer_file_bin); shared_ptr compairer; Deck deck; shared_ptr game_tree; diff --git a/include/runtime/qsolverjob.h b/include/runtime/qsolverjob.h index 507bb05..c701ee6 100644 --- a/include/runtime/qsolverjob.h +++ b/include/runtime/qsolverjob.h @@ -16,11 +16,11 @@ class QSolverJob : public QThread private: QSTextEdit * textEdit; public: - enum Mode{ - HOLDEM, - SHORTDECK - }; - Mode mode = Mode::HOLDEM; + // enum Mode{ + // HOLDEM, + // SHORTDECK + // }; + PokerMode mode = PokerMode::HOLDEM; enum MissionType{ LOADING, @@ -31,6 +31,7 @@ class QSolverJob : public QThread MissionType current_mission = MissionType::LOADING; string resource_dir; PokerSolver ps_holdem,ps_shortdeck; + /* float oop_commit=5; float ip_commit=5; int current_round=1; @@ -50,7 +51,9 @@ class QSolverJob : public QThread int print_interval=10; int dump_rounds = 2; shared_ptr gtbs; - + */ + CommandLineTool *clt = nullptr; + Logger *logger = nullptr; PokerSolver* get_solver(); void run(); void loading(); @@ -58,8 +61,8 @@ class QSolverJob : public QThread void stop(); void saving(); void build_tree(); - long long estimate_tree_memory(QString range1,QString range2,QString board); + long long estimate_tree_memory(string &range1, string &range2, string &board); void setContext(QSTextEdit * textEdit); - QString savefile; + // QString savefile; }; #endif // QSOLVERJOB_H diff --git a/include/solver/BestResponse.h b/include/solver/BestResponse.h index a161a8a..efc86e3 100644 --- a/include/solver/BestResponse.h +++ b/include/solver/BestResponse.h @@ -16,6 +16,7 @@ #include #include #include +#include "include/tools/logger.h" using namespace std; @@ -47,7 +48,7 @@ class BestResponse { ); float printExploitability(shared_ptr root, int iterationCount, float initial_pot, uint64_t initialBoard); float getBestReponseEv(shared_ptr node, int player,vector> reach_probs, uint64_t initialBoard,int deal); - + Logger *logger = nullptr; private: vector bestResponse(shared_ptr node, int player, const vector>& reach_probs, uint64_t board,int deal); vector chanceBestReponse(shared_ptr node, int player, const vector>& reach_probs, uint64_t current_board,int deal); diff --git a/include/solver/PCfrSolver.h b/include/solver/PCfrSolver.h index cce4f68..7017afb 100644 --- a/include/solver/PCfrSolver.h +++ b/include/solver/PCfrSolver.h @@ -83,7 +83,7 @@ class PCfrSolver:public Solver { int iteration_number, bool debug, int print_interval, - string logfile, + /*string logfile*/Logger *logger, string trainer, Solver::MonteCarolAlg monteCarolAlg, int warmup, diff --git a/include/solver/Solver.h b/include/solver/Solver.h index d7f7271..2cfd9c5 100644 --- a/include/solver/Solver.h +++ b/include/solver/Solver.h @@ -5,7 +5,7 @@ #ifndef TEXASSOLVER_SOLVER_H #define TEXASSOLVER_SOLVER_H - +#include "include/tools/logger.h" #include class Solver { @@ -15,7 +15,7 @@ class Solver { PUBLIC }; Solver(); - Solver(shared_ptr tree); + Solver(shared_ptr tree, Logger *logger); shared_ptr getTree(); virtual void train() = 0; virtual void stop() = 0; @@ -23,6 +23,7 @@ class Solver { virtual vector>> get_strategy(shared_ptr node,vector cards) = 0; virtual vector>> get_evs(shared_ptr node,vector cards) = 0; shared_ptr tree; + Logger *logger = nullptr; }; diff --git a/include/solver/cuda_cfr.h b/include/solver/cuda_cfr.h index 90f07fe..41ac019 100644 --- a/include/solver/cuda_cfr.h +++ b/include/solver/cuda_cfr.h @@ -5,12 +5,12 @@ #include #include #include -#include "nodes/GameTreeNode.h" -#include "solver/PCfrSolver.h" +#include "include/nodes/GameTreeNode.h" +#include "include/solver/PCfrSolver.h" #include #include #include "cuda_runtime.h" -#include "solver/slice_cfr.h" +#include "include/solver/slice_cfr.h" #define LANE_SIZE 32 @@ -45,8 +45,9 @@ class CudaCFR : public SliceCFR { int train_step, int print_interval, float accuracy, - int n_thread - ):SliceCFR(tree, range1, range2, initial_board, compairer, deck, train_step, print_interval, accuracy, n_thread) {} + int n_thread, + Logger *logger + ):SliceCFR(tree, range1, range2, initial_board, compairer, deck, train_step, print_interval, accuracy, n_thread, logger) {} virtual ~CudaCFR(); virtual size_t estimate_tree_size(); protected: diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h index eadf8f7..8c6e503 100644 --- a/include/solver/slice_cfr.h +++ b/include/solver/slice_cfr.h @@ -98,7 +98,8 @@ class SliceCFR : public Solver { int train_step, int print_interval, float accuracy, - int n_thread + int n_thread, + Logger *logger ); virtual ~SliceCFR(); virtual size_t estimate_tree_size(); @@ -146,7 +147,7 @@ class SliceCFR : public Solver { vector root_cfv, root_prob;// P0_cfv,P1_cfv,P0_prob,P1_prob float *root_prob_ptr[N_PLAYER] {nullptr,nullptr}; float *root_cfv_ptr[N_PLAYER] {nullptr,nullptr}; - shared_ptr tree = nullptr; + // shared_ptr tree = nullptr; Deck& deck; void init(); void init_hand_card(vector &range1, vector &range2); diff --git a/include/tools/CommandLineTool.h b/include/tools/CommandLineTool.h index d319c36..e302120 100644 --- a/include/tools/CommandLineTool.h +++ b/include/tools/CommandLineTool.h @@ -13,23 +13,62 @@ using namespace std; class CommandLineTool{ public: - CommandLineTool(string mode,string resource_dir); - void startWorking(); - void execFromFile(string input_file); - void processCommand(string input); -private: - enum Mode{ - HOLDEM, - SHORTDECK - }; - Mode mode; - string resource_dir; - PokerSolver ps; + CommandLineTool(); + void startWorking(PokerSolver *ps); + void execFromFile(const char *input_file, PokerSolver *ps); + void processCommand(string &input, PokerSolver *ps); + void dump_setting(const char *file); + void set_pot(float val) { + ip_commit = oop_commit = val / 2; + } + float get_pot() { + return ip_commit + oop_commit; + } + void set_effective_stack(float val) { + stack = val + ip_commit; + } + float get_effective_stack() { + return stack - ip_commit; + } + bool set_board(string &str); + bool set_bet_sizes(string &str, char delimiter = ',', vector *sizes = nullptr); + void build_tree(PokerSolver *ps) { + if(!ps) return; + ps->build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,gtbs,allin_threshold); + } + void start_solve(PokerSolver *ps) { + if(!ps) return; + // cout << "<<>>" << endl; + logger->log("<<>>"); + ps->train( + range_ip, + range_oop, + board, + // "tmp_log.txt", + max_iteration, + print_interval, + algorithm, + -1, + accuracy, + use_isomorphism, + use_halffloats, + thread_num, + slice_cfr + ); + } +// private: + // enum Mode{ + // HOLDEM, + // SHORTDECK + // }; + // Mode mode; + // string resource_dir; + // PokerSolver ps; float oop_commit=5; float ip_commit=5; int current_round=1; int raise_limit=4; - int thread_number=1; + int thread_num=1; float small_blind=0.5; float big_blind=1; float stack=20 + 5; @@ -37,13 +76,27 @@ class CommandLineTool{ string range_ip; string range_oop; string board; + string res_file; + string algorithm = "discounted_cfr"; float accuracy; int max_iteration=100; - int use_isomorphism=0; + bool use_isomorphism=0; + int use_halffloats=0; int print_interval=10; int slice_cfr = 0; int dump_rounds = 1; - shared_ptr gtbs; + GameTreeBuildingSettings gtbs; + Logger *logger = nullptr; }; +void split(const string& s, char delimiter, vector& v); +void join(const vector &vec, char delimiter, string &out); + +template +string tostring(T val); +template +string tostring_oss(T val); + +int cmd_api(string &input_file, string &resource_dir, string &mode, string &log_file); + #endif //BINDSOLVER_COMMANDLINETOOL_H diff --git a/include/tools/GameTreeBuildingSettings.h b/include/tools/GameTreeBuildingSettings.h index c72da89..394d193 100644 --- a/include/tools/GameTreeBuildingSettings.h +++ b/include/tools/GameTreeBuildingSettings.h @@ -8,6 +8,7 @@ class GameTreeBuildingSettings { public: + GameTreeBuildingSettings() {} GameTreeBuildingSettings( StreetSetting flop_ip, StreetSetting turn_ip, @@ -21,7 +22,7 @@ class GameTreeBuildingSettings { StreetSetting flop_oop; StreetSetting turn_oop; StreetSetting river_oop; - StreetSetting& get_setting(string player,string round); + StreetSetting& get_setting(string &player, string &round); }; #endif //BINDSOLVER_GAMETREEBUILDINGSETTINGS_H diff --git a/include/tools/StreetSetting.h b/include/tools/StreetSetting.h index 6aa1e2b..9c4721c 100644 --- a/include/tools/StreetSetting.h +++ b/include/tools/StreetSetting.h @@ -12,8 +12,9 @@ class StreetSetting { vector bet_sizes; vector raise_sizes; vector donk_sizes; - bool allin; + bool allin = true; + StreetSetting() {} StreetSetting(vector bet_sizes, vector raise_sizes, vector donk_sizes, bool allin); }; diff --git a/include/tools/logger.h b/include/tools/logger.h new file mode 100644 index 0000000..ad17ac9 --- /dev/null +++ b/include/tools/logger.h @@ -0,0 +1,43 @@ +#if !defined(_LOGGER_H_) +#define _LOGGER_H_ + +#include +#include +#include +#include + +using std::string; + +void get_localtime(char *buf, size_t n, const char *format); +string get_localtime(); + +class Logger { +public: + Logger(bool cmd, const char *path, const char *mode = "w", bool timestamp = false, bool new_line = true, int period = 10) + :cmd(cmd), timestamp(timestamp), new_line(new_line), period(period) { + if(path) { + errno_t err = fopen_s(&file, path, mode); + if(err) printf("%d\n", err); + if(!file) printf("create file %s failed\n", path); + } + } + virtual ~Logger() { + if(file) { + fflush(file); + fclose(file); + } + } + virtual void log(const char *format, ...); + void flush() { + if(file) fflush(file); + } +protected: + void log_time(); + int step = 0, period = 10; + FILE *file = nullptr; + bool timestamp = false; + bool cmd = true; + bool new_line = true; +}; + +#endif // _LOGGER_H_ diff --git a/include/ui/tablestrategymodel.h b/include/ui/tablestrategymodel.h index ede362a..44f098f 100644 --- a/include/ui/tablestrategymodel.h +++ b/include/ui/tablestrategymodel.h @@ -12,6 +12,7 @@ #include "include/ui/treeitem.h" #include "include/nodes/GameActions.h" #include +#include "include/card_format.h" class TableStrategyModel : public QAbstractItemModel { diff --git a/mainwindow.cpp b/mainwindow.cpp index ec81aa4..07a8ec7 100644 --- a/mainwindow.cpp +++ b/mainwindow.cpp @@ -18,17 +18,21 @@ MainWindow::MainWindow(QWidget *parent) : connect(this->ui->actionimport, &QAction::triggered, this, &MainWindow::on_actionimport_triggered); connect(this->ui->actionexport, &QAction::triggered, this, &MainWindow::on_actionexport_triggered); connect(this->ui->actionclear_all, &QAction::triggered, this, &MainWindow::on_actionclear_all_triggered); + logger = new QLogger((get_localtime() + ".txt").c_str(), "w", false, 1); + clt.logger = logger; qSolverJob = new QSolverJob; + qSolverJob->clt = &clt; qSolverJob->setContext(this->getLogArea()); + qSolverJob->logger = logger; qSolverJob->current_mission = QSolverJob::MissionType::LOADING; qSolverJob->start(); this->setWindowTitle(tr("TexasSolver")); - // parameters tree view QStringList filters; filters << "*.txt"; qFileSystemModel = new QFileSystemModel(this); - QDir filedir = QDir::current().filePath("parameters"); + QDir filedir = QDir::current()/*.filePath("parameters")*/; + logger->log(filedir.absolutePath().toLocal8Bit()); qFileSystemModel->setRootPath(filedir.path()); #ifdef Q_OS_MAC filedir = QDir(""); @@ -63,6 +67,11 @@ MainWindow::MainWindow(QWidget *parent) : this->ui->oopRangeTableView->verticalHeader()->setMinimumSectionSize(1); this->ui->oopRangeTableView->horizontalHeader()->setMinimumSectionSize(1); this->ui->tabWidget->hide(); + + show_tree_params(); + show_solver_params(); + this->update(); + update_range_ui(); } QSTextEdit * MainWindow::get_logwindow(){ @@ -78,6 +87,7 @@ MainWindow::~MainWindow() delete oop_delegate; delete oop_model; delete ui; + if(logger) delete logger; } void MainWindow::on_actionjson_triggered(){ @@ -85,7 +95,11 @@ void MainWindow::on_actionjson_triggered(){ "output_strategy.json", tr("Json file (*.json)")); if(fileName.isNull())return; - this->qSolverJob->savefile = fileName; + QSettings setting("TexasSolver", "Setting"); + setting.beginGroup("solver"); + clt.dump_rounds = setting.value("dump_round").toInt(); + clt.res_file = (const char*)fileName.toLocal8Bit(); + // this->qSolverJob->savefile = fileName; qSolverJob->current_mission = QSolverJob::MissionType::SAVING; qSolverJob->start(); } @@ -100,10 +114,7 @@ QString getParams(QString input,QString key){ void MainWindow::on_actionclear_all_triggered(){ this->clear_all_params(); - this->ui->IpRangeTableView->update(); - this->ui->oopRangeTableView->update(); - this->ui->IpRangeTableView->setFocus(); - this->ui->oopRangeTableView->setFocus(); + update_range_ui(); } void MainWindow::clear_all_params(){ @@ -147,6 +158,7 @@ void MainWindow::import_from_file(QString fileName){ qDebug().noquote() << tr("File selection invalid."); return; } + /* QFile file(fileName); if(!file.open(QIODevice::ReadOnly)){ qDebug().noquote() << tr("File open failed."); @@ -261,7 +273,10 @@ void MainWindow::import_from_file(QString fileName){ this->ui->useIsoCheck->setChecked(false); } } - } + }*/ + clt.execFromFile(fileName.toLocal8Bit(), nullptr); + show_tree_params(); + show_solver_params(); this->update(); } @@ -272,6 +287,9 @@ void MainWindow::on_actionimport_triggered(){ QDir::currentPath(), tr("Text files (*.txt)")); this->import_from_file(fileName); + update_range_ui(); +} +void MainWindow::update_range_ui() { this->ui->IpRangeTableView->update(); this->ui->oopRangeTableView->update(); this->ui->IpRangeTableView->setFocus(); @@ -283,7 +301,8 @@ void MainWindow::on_actionexport_triggered(){ "parameters/output_parameters.txt", tr("Text file (*.txt)")); if(fileName.isNull())return; - QString output_text = ""; + clt.dump_setting(fileName.toLocal8Bit()); + /*QString output_text = ""; QTextStream out(&output_text); out << "set_pot " << this->ui->potText->text().trimmed(); out << "\n"; @@ -398,6 +417,7 @@ void MainWindow::on_actionexport_triggered(){ msgBox.setText(message); setlocale(LC_CTYPE, "C"); msgBox.exec(); + */ } void MainWindow::on_actionSettings_triggered(){ @@ -412,10 +432,15 @@ void MainWindow::on_ip_range(QString range_text){ void MainWindow::on_buttomSolve_clicked() { + /* qSolverJob->max_iteration = ui->iterationText->text().toInt(); qSolverJob->accuracy = ui->exploitabilityText->text().toFloat(); qSolverJob->print_interval = ui->logIntervalText->text().toInt(); qSolverJob->thread_number = ui->threadsText->text().toInt(); + */ + get_solver_params(); + show_solver_params(); + this->update(); qSolverJob->current_mission = QSolverJob::MissionType::SOLVING; qSolverJob->start(); } @@ -451,6 +476,7 @@ vector sizes_convert(QString input){ void MainWindow::on_buildTreeButtom_clicked() { + /* qSolverJob->range_ip = this->ui->ipRangeText->toPlainText().toStdString(); qSolverJob->range_oop = this->ui->oopRangeText->toPlainText().toStdString(); qSolverJob->board = this->ui->boardText->toPlainText().toStdString(); @@ -470,7 +496,7 @@ void MainWindow::on_buildTreeButtom_clicked() qSolverJob->ip_commit = this->ui->potText->text().toFloat() / 2; qSolverJob->oop_commit = this->ui->potText->text().toFloat() / 2; qSolverJob->stack = this->ui->effectiveStackText->text().toFloat() + qSolverJob->ip_commit; - qSolverJob->mode = this->ui->mode_box->currentIndex() == 0 ? QSolverJob::Mode::HOLDEM:QSolverJob::Mode::SHORTDECK; + qSolverJob->mode = this->ui->mode_box->currentIndex() == 0 ? PokerMode::HOLDEM : PokerMode::SHORTDECK; qSolverJob->allin_threshold = this->ui->allinThresholdText->text().toFloat(); qSolverJob->use_isomorphism = this->ui->useIsoCheck->isChecked(); qSolverJob->use_halffloats = this->ui->useHalfFloats_box->currentIndex(); @@ -508,10 +534,103 @@ void MainWindow::on_buildTreeButtom_clicked() ); qSolverJob->gtbs = make_shared(gbs_flop_ip,gbs_turn_ip,gbs_river_ip,gbs_flop_oop,gbs_turn_oop,gbs_river_oop); + */ + get_tree_params(); + show_tree_params(); + this->update(); + this->ui->IpRangeTableView->update(); + this->ui->oopRangeTableView->update(); qSolverJob->current_mission = QSolverJob::MissionType::BUILDTREE; qSolverJob->start(); } +void MainWindow::get_tree_params() { + qSolverJob->mode = this->ui->mode_box->currentIndex() == 0 ? PokerMode::HOLDEM : PokerMode::SHORTDECK; + string val = this->ui->boardText->toPlainText().toStdString(); + if(!clt.set_board(val)) { + qDebug().noquote() << tfm::format("Error : board %s not recognized", val).c_str(); + return; + } + clt.range_ip = this->ui->ipRangeText->toPlainText().toStdString(); + clt.range_oop = this->ui->oopRangeText->toPlainText().toStdString(); + clt.raise_limit = this->ui->raiseLimitText->text().toInt(); + clt.set_pot(this->ui->potText->text().toFloat()); + clt.set_effective_stack(this->ui->effectiveStackText->text().toFloat()); + clt.allin_threshold = this->ui->allinThresholdText->text().toFloat(); + + set_bet_sizes(ui->flop_ip_bet, &clt.gtbs.flop_ip.bet_sizes); + set_bet_sizes(ui->flop_ip_raise, &clt.gtbs.flop_ip.raise_sizes); + clt.gtbs.flop_ip.allin = ui->flop_ip_allin->isChecked(); + set_bet_sizes(ui->turn_ip_bet, &clt.gtbs.turn_ip.bet_sizes); + set_bet_sizes(ui->turn_ip_raise, &clt.gtbs.turn_ip.raise_sizes); + clt.gtbs.turn_ip.allin = ui->turn_ip_allin->isChecked(); + set_bet_sizes(ui->river_ip_bet, &clt.gtbs.river_ip.bet_sizes); + set_bet_sizes(ui->river_ip_raise, &clt.gtbs.river_ip.raise_sizes); + clt.gtbs.river_ip.allin = ui->river_ip_allin->isChecked(); + + set_bet_sizes(ui->flop_oop_bet, &clt.gtbs.flop_oop.bet_sizes); + set_bet_sizes(ui->flop_oop_raise, &clt.gtbs.flop_oop.raise_sizes); + clt.gtbs.flop_oop.allin = ui->flop_oop_allin->isChecked(); + set_bet_sizes(ui->turn_oop_bet, &clt.gtbs.turn_oop.bet_sizes); + set_bet_sizes(ui->turn_oop_raise, &clt.gtbs.turn_oop.raise_sizes); + set_bet_sizes(ui->turn_oop_donk, &clt.gtbs.turn_oop.donk_sizes); + clt.gtbs.turn_oop.allin = ui->turn_oop_allin->isChecked(); + set_bet_sizes(ui->river_oop_bet, &clt.gtbs.river_oop.bet_sizes); + set_bet_sizes(ui->river_oop_raise, &clt.gtbs.river_oop.raise_sizes); + set_bet_sizes(ui->river_oop_donk, &clt.gtbs.river_oop.donk_sizes); + clt.gtbs.river_oop.allin = ui->river_oop_allin->isChecked(); +} + +void MainWindow::get_solver_params() { + clt.use_isomorphism = this->ui->useIsoCheck->isChecked(); + clt.use_halffloats = this->ui->useHalfFloats_box->currentIndex(); + clt.max_iteration = ui->iterationText->text().toInt(); + clt.accuracy = ui->exploitabilityText->text().toFloat(); + clt.print_interval = ui->logIntervalText->text().toInt(); + clt.thread_num = ui->threadsText->text().toInt(); +} + +void MainWindow::show_tree_params() { + ui->boardText->setText(clt.board.c_str()); + ui->ipRangeText->setText(clt.range_ip.c_str()); + ui->oopRangeText->setText(clt.range_oop.c_str()); + ui->raiseLimitText->setText(QString::number(clt.raise_limit)); + ui->potText->setText(QString::number(clt.get_pot())); + ui->effectiveStackText->setText(QString::number(clt.get_effective_stack())); + ui->allinThresholdText->setText(QString::number(clt.allin_threshold)); + + show_bet_sizes(ui->flop_ip_bet, clt.gtbs.flop_ip.bet_sizes); + show_bet_sizes(ui->flop_ip_raise, clt.gtbs.flop_ip.raise_sizes); + ui->flop_ip_allin->setChecked(clt.gtbs.flop_ip.allin); + show_bet_sizes(ui->turn_ip_bet, clt.gtbs.turn_ip.bet_sizes); + show_bet_sizes(ui->turn_ip_raise, clt.gtbs.turn_ip.raise_sizes); + ui->turn_ip_allin->setChecked(clt.gtbs.turn_ip.allin); + show_bet_sizes(ui->river_ip_bet, clt.gtbs.river_ip.bet_sizes); + show_bet_sizes(ui->river_ip_raise, clt.gtbs.river_ip.raise_sizes); + ui->river_ip_allin->setChecked(clt.gtbs.river_ip.allin); + + show_bet_sizes(ui->flop_oop_bet, clt.gtbs.flop_oop.bet_sizes); + show_bet_sizes(ui->flop_oop_raise, clt.gtbs.flop_oop.raise_sizes); + ui->flop_oop_allin->setChecked(clt.gtbs.flop_oop.allin); + show_bet_sizes(ui->turn_oop_bet, clt.gtbs.turn_oop.bet_sizes); + show_bet_sizes(ui->turn_oop_raise, clt.gtbs.turn_oop.raise_sizes); + show_bet_sizes(ui->turn_oop_donk, clt.gtbs.turn_oop.donk_sizes); + ui->turn_oop_allin->setChecked(clt.gtbs.turn_oop.allin); + show_bet_sizes(ui->river_oop_bet, clt.gtbs.river_oop.bet_sizes); + show_bet_sizes(ui->river_oop_raise, clt.gtbs.river_oop.raise_sizes); + show_bet_sizes(ui->river_oop_donk, clt.gtbs.river_oop.donk_sizes); + ui->river_oop_allin->setChecked(clt.gtbs.river_oop.allin); +} + +void MainWindow::show_solver_params() { + ui->useIsoCheck->setChecked(clt.use_isomorphism); + ui->useHalfFloats_box->setCurrentIndex(clt.use_halffloats); + ui->iterationText->setText(QString::number(clt.max_iteration)); + ui->exploitabilityText->setText(QString::number(clt.accuracy)); + ui->logIntervalText->setText(QString::number(clt.print_interval)); + ui->threadsText->setText(QString::number(clt.thread_num)); +} + void MainWindow::on_copyButtom_clicked() { ui->flop_oop_bet->setText(ui->flop_ip_bet->text()); @@ -543,7 +662,7 @@ void MainWindow::on_stopSolvingButton_clicked() void MainWindow::on_ipRangeSelectButtom_clicked() { - QSolverJob::Mode mode = this->ui->mode_box->currentIndex() == 0 ? QSolverJob::Mode::HOLDEM:QSolverJob::Mode::SHORTDECK; + PokerMode mode = this->ui->mode_box->currentIndex() == 0 ? PokerMode::HOLDEM:PokerMode::SHORTDECK; this->rangeSelector = new RangeSelector(this->ui->ipRangeText,this,mode); rangeSelector->setAttribute(Qt::WA_DeleteOnClose); rangeSelector->show(); @@ -551,14 +670,15 @@ void MainWindow::on_ipRangeSelectButtom_clicked() void MainWindow::on_oopRangeSelectButtom_clicked() { - QSolverJob::Mode mode = this->ui->mode_box->currentIndex() == 0 ? QSolverJob::Mode::HOLDEM:QSolverJob::Mode::SHORTDECK; + PokerMode mode = this->ui->mode_box->currentIndex() == 0 ? PokerMode::HOLDEM:PokerMode::SHORTDECK; this->rangeSelector = new RangeSelector(this->ui->oopRangeText,this,mode); rangeSelector->setAttribute(Qt::WA_DeleteOnClose); rangeSelector->show(); } float iso_corh(QString board){ - vector board_str_arr = string_split(board.toStdString(),','); + string board_str = board.toStdString(); + vector board_str_arr = string_split(board_str, ','); vector initialBoard; for(string one_board_str:board_str_arr){ initialBoard.push_back(Card(one_board_str)); @@ -584,7 +704,7 @@ float iso_corh(QString board){ void MainWindow::on_estimateMemoryButtom_clicked() { - long long memory_float = this->qSolverJob->estimate_tree_memory(this->ui->ipRangeText->toPlainText(),this->ui->oopRangeText->toPlainText(),this->ui->boardText->toPlainText()); + long long memory_float = this->qSolverJob->estimate_tree_memory(clt.range_ip, clt.range_oop, clt.board); // float32 should take 4bytes float corh = 1; if(this->ui->useIsoCheck->isChecked()){ @@ -620,7 +740,7 @@ void MainWindow::on_estimateMemoryButtom_clicked() void MainWindow::on_selectBoardButton_clicked() { - QSolverJob::Mode mode = this->ui->mode_box->currentIndex() == 0 ? QSolverJob::Mode::HOLDEM:QSolverJob::Mode::SHORTDECK; + PokerMode mode = this->ui->mode_box->currentIndex() == 0 ? PokerMode::HOLDEM:PokerMode::SHORTDECK; this->boardSelector = new boardselector(this->ui->boardText,mode,this); boardSelector->setAttribute(Qt::WA_DeleteOnClose); boardSelector->show(); @@ -636,10 +756,7 @@ void MainWindow::item_clicked(const QModelIndex& index){ QFileInfo fileinfo = QFileInfo(this->qFileSystemModel->filePath(index)); if(fileinfo.suffix() == "txt"){ this->import_from_file(this->qFileSystemModel->filePath(index)); - this->ui->IpRangeTableView->update(); - this->ui->oopRangeTableView->update(); - this->ui->IpRangeTableView->setFocus(); - this->ui->oopRangeTableView->setFocus(); + update_range_ui(); } } } diff --git a/mainwindow.h b/mainwindow.h index f6ee1fd..8efce20 100644 --- a/mainwindow.h +++ b/mainwindow.h @@ -13,6 +13,28 @@ #include "settingeditor.h" #include "include/ui/rangeselectortablemodel.h" #include "include/ui/rangeselectortabledelegate.h" +#include + +class QLogger : public Logger { +public: + QLogger(const char *path, const char *mode = "w", bool timestamp = false, int period = 10):Logger(false, path, mode, timestamp, true, period) {} + virtual void log(const char *format, ...) { + if(timestamp) log_time(); + va_list args = nullptr; + va_start(args, format); + if(file) { + vfprintf(file, format, args); + if((++step) == period) { + step = 0; + fflush(file); + } + if(new_line) fprintf(file, "\n"); + } + // qDebug().noquote() << QString::vasprintf(QObject::tr(format).toLocal8Bit(), args); + qDebug().noquote() << QString::vasprintf(QObject::tr(format).toStdString().c_str(), args); + va_end(args); + } +}; namespace Ui { class MainWindow; @@ -60,8 +82,24 @@ private slots: private: void clear_all_params(); + void get_tree_params(); + void get_solver_params(); + void show_tree_params(); + void show_solver_params(); + void set_bet_sizes(QLineEdit *edit, vector *sizes) { + string s = edit->text().toStdString(); + clt.set_bet_sizes(s, ' ', sizes); + } + void show_bet_sizes(QLineEdit *edit, vector &sizes) { + string s; + join(sizes, ' ', s); + edit->setText(s.c_str()); + } + void update_range_ui(); Ui::MainWindow *ui = NULL; QSolverJob* qSolverJob = NULL; + CommandLineTool clt; + Logger *logger = nullptr; QFileSystemModel * qFileSystemModel = NULL; StrategyExplorer* strategyExplorer = NULL; RangeSelector* rangeSelector = NULL; diff --git a/rangeselector.cpp b/rangeselector.cpp index 8e81a72..24db7c2 100644 --- a/rangeselector.cpp +++ b/rangeselector.cpp @@ -1,15 +1,15 @@ #include "rangeselector.h" #include "ui_rangeselector.h" -RangeSelector::RangeSelector(QTextEdit* rangeEdit,QWidget *parent,QSolverJob::Mode mode) : +RangeSelector::RangeSelector(QTextEdit* rangeEdit,QWidget *parent,PokerMode mode) : QDialog(parent), ui(new Ui::RangeSelector) { QString ranks; - if(mode == QSolverJob::Mode::HOLDEM){ + if(mode == PokerMode::HOLDEM){ ranks = "A,K,Q,J,T,9,8,7,6,5,4,3,2"; - }else if(mode == QSolverJob::Mode::SHORTDECK){ + }else if(mode == PokerMode::SHORTDECK){ ranks = "A,K,Q,J,T,9,8,7,6"; }else{ throw runtime_error("mode not found in range selector"); diff --git a/rangeselector.h b/rangeselector.h index 447e370..40438ad 100644 --- a/rangeselector.h +++ b/rangeselector.h @@ -24,14 +24,14 @@ class RangeSelector : public QDialog Q_OBJECT public: - explicit RangeSelector(QTextEdit* rangeEdit,QWidget *parent = 0,QSolverJob::Mode mode = QSolverJob::Mode::HOLDEM); + explicit RangeSelector(QTextEdit* rangeEdit,QWidget *parent = 0,PokerMode mode = PokerMode::HOLDEM); ~RangeSelector(); signals: void confirm_text(QString content); private: int max_val = 1000; float range_num = 1; - QSolverJob::Mode mode; + PokerMode mode; Ui::RangeSelector *ui; QStringList rank_list; RangeSelectorTableModel * rangeSelectorTableModel = NULL; diff --git a/src/Card.cpp b/src/Card.cpp index ca91ac1..8e9f04d 100644 --- a/src/Card.cpp +++ b/src/Card.cpp @@ -22,7 +22,7 @@ bool Card::empty(){ else return false; } -string Card::getCard() { +const string& Card::getCard() { return this->card; } @@ -39,7 +39,7 @@ int Card::card2int(Card card) { return strCard2int(card.getCard()); } -int Card::strCard2int(string card) { +int Card::strCard2int(const string &card) { char rank = card.at(0); char suit = card.at(1); if(card.length() != 2){ @@ -74,14 +74,14 @@ uint64_t Card::boardCards2long(vector& cards){ return Card::boardInts2long(board_int); } -QString Card::boardCards2html(vector& cards){ +/*QString Card::boardCards2html(vector& cards){ QString ret_html = ""; for(auto one_card:cards){ if(one_card.empty())continue; ret_html += one_card.toFormattedHtml(); } return ret_html; -} +}*/ uint64_t Card::boardInt2long(int board){ // 这里hard code了一副扑克牌是52张 @@ -217,7 +217,7 @@ vector Card::getSuits(){ return {"c","d","h","s"}; } -string Card::toString() { +/*string Card::toString() { return this->card; } @@ -241,4 +241,4 @@ QString Card::toFormattedHtml() { else if(qString.contains("s")) qString = qString.replace("s", QString::fromLocal8Bit("♠<\/span>")); return qString; -} +}*/ diff --git a/src/api.cpp b/src/api.cpp index 5c9879c..1d28a08 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -12,20 +12,11 @@ #include EXPORT -int api(const char * input_file, const char * resource_dir = "./resources", const char * mode = "holdem") { +int api(const char * input_file, const char * resource_dir = "./resources", const char * mode = "holdem", const char *log_file = "") { string input_file_ = input_file; string resource_dir_ = resource_dir; string mode_ = mode; + string log_file_ = log_file; - if(mode_ != "holdem" && mode_ != "shortdeck") - throw runtime_error(tfm::format("mode %s error, not in ['holdem','shortdeck']", mode_)); - - if(input_file_.empty()) { - CommandLineTool clt = CommandLineTool(mode_, resource_dir_); - clt.startWorking(); - }else{ - cout << "EXEC FROM FILE" << endl; - CommandLineTool clt = CommandLineTool(mode_, resource_dir_); - clt.execFromFile(input_file_); - } + return cmd_api(input_file_, resource_dir_, mode_, log_file_); } \ No newline at end of file diff --git a/src/card_format.cpp b/src/card_format.cpp new file mode 100644 index 0000000..0a908e9 --- /dev/null +++ b/src/card_format.cpp @@ -0,0 +1,32 @@ +#include "include/card_format.h" + +string toFormattedString(Card &card) { + QString qString = QString::fromStdString(card.getCard()); + qString = qString.replace("c", "♣️"); + qString = qString.replace("d", "♦️"); + qString = qString.replace("h", "♥️"); + qString = qString.replace("s", "♠️"); + return qString.toStdString(); +} + +QString toFormattedHtml(Card &card) { + QString qString = QString::fromStdString(card.getCard()); + if(qString.contains("c")) + qString = qString.replace("c", QString::fromLocal8Bit("♣<\/span>")); + else if(qString.contains("d")) + qString = qString.replace("d", QString::fromLocal8Bit("♦<\/span>")); + else if(qString.contains("h")) + qString = qString.replace("h", QString::fromLocal8Bit("♥<\/span>")); + else if(qString.contains("s")) + qString = qString.replace("s", QString::fromLocal8Bit("♠<\/span>")); + return qString; +} + +QString boardCards2html(vector& cards){ + QString ret_html = ""; + for(auto one_card:cards){ + if(one_card.empty())continue; + ret_html += toFormattedHtml(one_card); + } + return ret_html; +} diff --git a/src/compairer/Dic5Compairer.cpp b/src/compairer/Dic5Compairer.cpp index e1133ff..d6df6aa 100644 --- a/src/compairer/Dic5Compairer.cpp +++ b/src/compairer/Dic5Compairer.cpp @@ -5,9 +5,9 @@ #include "include/compairer/Dic5Compairer.h" #include -#include -#include -#include +// #include +// #include +// #include #include "time.h" #ifndef _MSC_VER #include "unistd.h" @@ -53,14 +53,14 @@ void FiveCardsStrength::convert(unordered_map& strength_map) { } } bool FiveCardsStrength::load(const char* file_path) { - //ifstream file(file_path, ios::binary); - /*if (!file) { + ifstream file(file_path, ios::binary); + if (!file.is_open()) { file.close(); - return false; - }*/ + /*return false; + } QFile file(QString::fromStdString(file_path)); - if (!file.open(QIODevice::ReadOnly)){ + if (!file.open(QIODevice::ReadOnly)){*/ throw runtime_error("unable to load compairer file"); } flush_map.clear(); other_map.clear(); @@ -90,7 +90,7 @@ bool FiveCardsStrength::save(const char* file_path) { //qDebug() << "b"; //file_path = "/Users/bytedance/Desktop/card5_dic_zipped_shortdeck.bin"; ofstream file(file_path, ios::binary); - if (!file) { + if (!file.is_open()) { file.close(); return false; } @@ -138,17 +138,22 @@ bool FiveCardsStrength::check(unordered_map& strength_map) { Dic5Compairer::Dic5Compairer(string dic_dir,int lines,string dic_dir_bin):Compairer(std::move(dic_dir),lines){ if(fcs.load(dic_dir_bin.c_str())) return; - QFile infile(QString::fromStdString(this->dic_dir)); + std::ifstream infile(this->dic_dir); + if(!infile.is_open()) { + throw runtime_error("unable to load compairer file"); + } + /*QFile infile(QString::fromStdString(this->dic_dir)); if (!infile.open(QIODevice::ReadOnly)){ throw runtime_error("unable to load compairer file"); } QTextStream in(&infile); - //progressbar bar(lines / 1000); + //progressbar bar(lines / 1000);*/ + string line; int i = 0; - //while (std::getline(infile, line)) - while (!in.atEnd()) + while (std::getline(infile, line)) + // while (!in.atEnd()) { - string line = in.readLine().toStdString(); + // string line = in.readLine().toStdString(); vector linesp = string_split(line,','); if(linesp.size() != 2){ throw runtime_error(tfm::format("linesp not correct: %s",line)); diff --git a/src/console.cpp b/src/console.cpp index bd55432..a858ef0 100644 --- a/src/console.cpp +++ b/src/console.cpp @@ -10,25 +10,13 @@ int main(int argc,const char **argv) { parser.addArgument("-i", "--input_file", 1, true); parser.addArgument("-r", "--resource_dir", 1, true); parser.addArgument("-m", "--mode", 1, true); + parser.addArgument("-l", "--log", 1, true); parser.parse(argc, argv); string input_file = parser.retrieve("input_file"); string resource_dir = parser.retrieve("resource_dir"); - if(resource_dir.empty()){ - resource_dir = "./resources"; - } string mode = parser.retrieve("mode"); - if(mode.empty()){mode = "holdem";} - if(mode != "holdem" && mode != "shortdeck") - throw runtime_error(tfm::format("mode %s error, not in ['holdem','shortdeck']",mode)); - - if(input_file.empty()) { - CommandLineTool clt = CommandLineTool(mode,resource_dir); - clt.startWorking(); - }else{ - cout << "EXEC FROM FILE" << endl; - CommandLineTool clt = CommandLineTool(mode,resource_dir); - clt.execFromFile(input_file); - } + string log_file = parser.retrieve("log"); + return cmd_api(input_file, resource_dir, mode, log_file); } diff --git a/src/library.cpp b/src/library.cpp index 03f2f10..1d16fb4 100644 --- a/src/library.cpp +++ b/src/library.cpp @@ -6,7 +6,7 @@ -vector string_split(string strin,char split){ +vector string_split(string &strin, char split){ vector retval; stringstream ss(strin); string token; diff --git a/src/nodes/GameActions.cpp b/src/nodes/GameActions.cpp index 920a9be..c361bfd 100644 --- a/src/nodes/GameActions.cpp +++ b/src/nodes/GameActions.cpp @@ -44,6 +44,8 @@ string GameActions::toString() { if(this->amount == -1) { return this->pokerActionToString(this->action); }else{ - return this->pokerActionToString(this->action) + " " + to_string(amount); + ostringstream oss; + oss << amount; + return this->pokerActionToString(this->action) + " " + oss.str(); } } diff --git a/src/runtime/PokerSolver.cpp b/src/runtime/PokerSolver.cpp index 07966f9..2fe9ecd 100644 --- a/src/runtime/PokerSolver.cpp +++ b/src/runtime/PokerSolver.cpp @@ -3,12 +3,35 @@ // #include "include/runtime/PokerSolver.h" - -PokerSolver::PokerSolver() { - +#ifdef USE_CUDA +#include "include/solver/cuda_cfr.h" +#endif + +PokerSolver::PokerSolver(PokerMode mode, string &resource_dir) { + string suits = "c,d,h,s"; + string ranks; + string compairer_file, compairer_file_bin; + int lines; + if(mode == PokerMode::HOLDEM){ + ranks = "2,3,4,5,6,7,8,9,T,J,Q,K,A"; + compairer_file = resource_dir + "/compairer/card5_dic_sorted.txt"; + compairer_file_bin = resource_dir + "/compairer/card5_dic_zipped.bin"; + lines = 2598961; + }else if(mode == PokerMode::SHORTDECK){ + ranks = "6,7,8,9,T,J,Q,K,A"; + compairer_file = resource_dir + "/compairer/card5_dic_sorted_shortdeck.txt"; + compairer_file_bin = resource_dir + "/compairer/card5_dic_zipped_shortdeck.bin"; + lines = 376993; + }else{ + throw runtime_error(tfm::format("mode not recognized : ",mode)); + } + init(ranks, suits, compairer_file, lines, compairer_file_bin); } -PokerSolver::PokerSolver(string ranks, string suits, string compairer_file,int compairer_file_lines, string compairer_file_bin) { +PokerSolver::PokerSolver(string &ranks, string &suits, string &compairer_file, int compairer_file_lines, string &compairer_file_bin) { + init(ranks, suits, compairer_file, compairer_file_lines, compairer_file_bin); +} +void PokerSolver::init(string &ranks, string &suits, string &compairer_file, int compairer_file_lines, string &compairer_file_bin) { vector ranks_vector = string_split(ranks,','); vector suits_vector = string_split(suits,','); this->deck = Deck(ranks_vector,suits_vector); @@ -69,7 +92,8 @@ void PokerSolver::stop(){ long long PokerSolver::estimate_tree_memory(string &p1_range, string &p2_range, string &board){ if(this->game_tree == nullptr){ - qDebug().noquote() << QObject::tr("Please buld tree first."); + // qDebug().noquote() << QObject::tr("Please buld tree first."); + logger->log("Please buld tree first."); return 0; } else{ @@ -85,8 +109,12 @@ long long PokerSolver::estimate_tree_memory(string &p1_range, string &p2_range, } } -void PokerSolver::train(string p1_range, string p2_range, string boards, string log_file, int iteration_number, - int print_interval, string algorithm,int warmup,float accuracy,bool use_isomorphism, int use_halffloats, int threads, int slice_cfr) { +void PokerSolver::train(string &p1_range, string &p2_range, string &boards, /*string &log_file,*/ int iteration_number, + int print_interval, string &algorithm,int warmup,float accuracy,bool use_isomorphism, int use_halffloats, int threads, int slice_cfr) { + if(game_tree == nullptr) { + logger->log("Please buld tree first."); + return; + } string player1RangeStr = p1_range; string player2RangeStr = p2_range; @@ -103,14 +131,19 @@ void PokerSolver::train(string p1_range, string p2_range, string boards, string this->player1Range = noDuplicateRange(range1,initial_board_long); this->player2Range = noDuplicateRange(range2,initial_board_long); - string logfile_name = log_file; + // string logfile_name = log_file; if(solver) solver.reset();// 释放内存 try { if(slice_cfr == 1) { - solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads); + solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads, logger); } else if(slice_cfr == 2) { - solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads); +#ifdef USE_CUDA + solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads, logger); +#else + logger->log("please set USE_CUDA ON in CMakeLists.txt and rebuild project"); + return; +#endif } else { solver = make_shared( @@ -123,7 +156,7 @@ void PokerSolver::train(string p1_range, string p2_range, string boards, string , iteration_number , false , print_interval - , logfile_name + , /*logfile_name*/logger , algorithm , Solver::MonteCarolAlg::NONE , warmup @@ -140,21 +173,23 @@ void PokerSolver::train(string p1_range, string p2_range, string boards, string } } -void PokerSolver::dump_strategy(QString dump_file,int dump_rounds) { +void PokerSolver::dump_strategy(string &dump_file, int dump_rounds) { //locale &loc=locale::global(locale(locale(),"",LC_CTYPE)); setlocale(LC_ALL,""); json dump_json = this->solver->dumps(false,dump_rounds); //QFile ofile( QString::fromStdString(dump_file)); ofstream fileWriter; - fileWriter.open(dump_file.toLocal8Bit()); + fileWriter.open(dump_file); if(!fileWriter.fail()){ fileWriter << dump_json; fileWriter.flush(); fileWriter.close(); - qDebug().noquote() << QObject::tr("save success"); + // qDebug().noquote() << QObject::tr("save success"); + logger->log("save success"); }else{ - qDebug().noquote() << QObject::tr("save failed, file cannot be open"); + // qDebug().noquote() << QObject::tr("save failed, file cannot be open"); + logger->log("save failed, file cannot be open"); } setlocale(LC_CTYPE, "C"); } diff --git a/src/runtime/qsolverjob.cpp b/src/runtime/qsolverjob.cpp index 4c71221..fa37936 100644 --- a/src/runtime/qsolverjob.cpp +++ b/src/runtime/qsolverjob.cpp @@ -10,9 +10,9 @@ void QSolverJob:: setContext(QSTextEdit * textEdit){ } PokerSolver* QSolverJob::get_solver(){ - if(this->mode == Mode::HOLDEM){ + if(this->mode == PokerMode::HOLDEM){ return &this->ps_holdem; - }else if(this->mode == Mode::SHORTDECK){ + }else if(this->mode == PokerMode::SHORTDECK){ return &this->ps_shortdeck; }else throw runtime_error("unknown mode in get_solver"); } @@ -35,12 +35,14 @@ void QSolverJob::run() } catch (const runtime_error& error) { - qDebug().noquote() << tr("Encountering error:");//.toStdString() << endl; - qDebug().noquote() << error.what() << "\n"; + // qDebug().noquote() << tr("Encountering error:");//.toStdString() << endl; + // qDebug().noquote() << error.what() << "\n"; + logger->log("Encountering error:\n%s\n", error.what()); } } void QSolverJob::loading(){ + /* string suits = "c,d,h,s"; string ranks; this->resource_dir = ":/resources"; @@ -63,33 +65,50 @@ void QSolverJob::loading(){ lines = 376993; this->ps_shortdeck = PokerSolver(ranks,suits,compairer_file,lines,compairer_file_bin); qDebug().noquote() << tr("Loading finished. Good to go.");//.toStdString() << endl; + */ + resource_dir = "./resources"; + // qDebug().noquote() << tr("Loading holdem compairing file");//.toStdString() << endl; + logger->log("Loading holdem compairing file"); + ps_holdem = PokerSolver(PokerMode::HOLDEM, resource_dir); + // qDebug().noquote() << tr("Loading shortdeck compairing file");//.toStdString() << endl; + logger->log("Loading shortdeck compairing file"); + ps_shortdeck = PokerSolver(PokerMode::SHORTDECK, resource_dir); + // qDebug().noquote() << tr("Loading finished. Good to go.");//.toStdString() << endl; + logger->log("Loading finished. Good to go."); + ps_holdem.logger = logger; + ps_shortdeck.logger = logger; } void QSolverJob::saving(){ - qDebug().noquote() << tr("Saving json file..");//.toStdString() << std::endl; - + // qDebug().noquote() << tr("Saving json file..");//.toStdString() << std::endl; + logger->log("Saving json file.."); + /* QSettings setting("TexasSolver", "Setting"); setting.beginGroup("solver"); this->dump_rounds = setting.value("dump_round").toInt(); - - qDebug().noquote() << tr("Dump round: ") << this->dump_rounds; - if(this->dump_rounds == 3){ - qDebug().noquote() << tr("This could be slow, or even blow your RAM, dump to river is not well optimized :("); + */ + // qDebug().noquote() << tr("Dump round: ") << clt->dump_rounds; + logger->log("Dump round: %d", clt->dump_rounds); + if(clt->dump_rounds == 3){ + // qDebug().noquote() << tr("This could be slow, or even blow your RAM, dump to river is not well optimized :("); + logger->log("This could be slow, or even blow your RAM, dump to river is not well optimized"); } - if(this->mode == Mode::HOLDEM){ - this->ps_holdem.dump_strategy(this->savefile,this->dump_rounds); - }else if(this->mode == Mode::SHORTDECK){ - this->ps_shortdeck.dump_strategy(this->savefile,this->dump_rounds); + if(this->mode == PokerMode::HOLDEM){ + this->ps_holdem.dump_strategy(clt->res_file, clt->dump_rounds); + }else if(this->mode == PokerMode::SHORTDECK){ + this->ps_shortdeck.dump_strategy(clt->res_file, clt->dump_rounds); } - qDebug().noquote() << tr("Saving done.");//.toStdString() << std::endl; + // qDebug().noquote() << tr("Saving done.");//.toStdString() << std::endl; + logger->log("Saving done."); } void QSolverJob::stop(){ - qDebug().noquote() << tr("Trying to stop solver."); - if(this->mode == Mode::HOLDEM){ + // qDebug().noquote() << tr("Trying to stop solver."); + logger->log("Trying to stop solver."); + if(this->mode == PokerMode::HOLDEM){ this->ps_holdem.stop(); - }else if(this->mode == Mode::SHORTDECK){ + }else if(this->mode == PokerMode::SHORTDECK){ this->ps_shortdeck.stop(); } } @@ -97,9 +116,12 @@ void QSolverJob::stop(){ void QSolverJob::solving(){ // TODO 为什么ui上多次求解会积累memory?哪里leak了? // TODO 为什么有时候会莫名闪退? - qDebug().noquote() << tr("Start Solving..");//.toStdString() << std::endl; - - if(this->mode == Mode::HOLDEM){ + // qDebug().noquote() << tr("Start Solving..");//.toStdString() << std::endl; + logger->log("Start Solving.."); + PokerSolver *ps = (mode == PokerMode::HOLDEM ? &ps_holdem : &ps_shortdeck); + clt->start_solve(ps); + /* + if(this->mode == PokerMode::HOLDEM){ this->ps_holdem.train( this->range_ip, this->range_oop, @@ -114,7 +136,7 @@ void QSolverJob::solving(){ this->use_halffloats, this->thread_number ); - }else if(this->mode == Mode::SHORTDECK){ + }else if(this->mode == PokerMode::SHORTDECK){ this->ps_shortdeck.train( this->range_ip, this->range_oop, @@ -130,28 +152,34 @@ void QSolverJob::solving(){ this->thread_number ); } - qDebug().noquote() << tr("Solving done.");//.toStdString() << std::endl; + */ + // qDebug().noquote() << tr("Solving done.");//.toStdString() << std::endl; + logger->log("Solving done."); } -long long QSolverJob::estimate_tree_memory(QString range1,QString range2,QString board){ - qDebug().noquote() << tr("Estimating tree memory..");//.toStdString() << endl; - string p1_range = range1.toStdString(); - string p2_range = range2.toStdString(); - string board_str = board.toStdString(); - if(this->mode == Mode::HOLDEM){ - return ps_holdem.estimate_tree_memory(p1_range, p1_range, board_str); - }else if(this->mode == Mode::SHORTDECK){ - return ps_shortdeck.estimate_tree_memory(p1_range, p1_range, board_str); +long long QSolverJob::estimate_tree_memory(string &range1, string &range2, string &board) { + // qDebug().noquote() << tr("Estimating tree memory..");//.toStdString() << endl; + logger->log("Estimating tree memory.."); + if(this->mode == PokerMode::HOLDEM){ + return ps_holdem.estimate_tree_memory(range1, range2, board); + }else if(this->mode == PokerMode::SHORTDECK){ + return ps_shortdeck.estimate_tree_memory(range1, range2, board); } return 0; } void QSolverJob::build_tree(){ - qDebug().noquote() << tr("building tree..");//.toStdString() << endl; - if(this->mode == Mode::HOLDEM){ + // qDebug().noquote() << tr("building tree..");//.toStdString() << endl; + logger->log("building tree.."); + PokerSolver *ps = (mode == PokerMode::HOLDEM ? &ps_holdem : &ps_shortdeck); + clt->build_tree(ps); + /* + if(this->mode == PokerMode::HOLDEM){ ps_holdem.build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,*gtbs.get(),allin_threshold); - }else if(this->mode == Mode::SHORTDECK){ + }else if(this->mode == PokerMode::SHORTDECK){ ps_shortdeck.build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,*gtbs.get(),allin_threshold); } - qDebug().noquote() << tr("build tree finished");//.toStdString() << endl; + */ + // qDebug().noquote() << tr("build tree finished");//.toStdString() << endl; + logger->log("build tree finished"); } diff --git a/src/solver/BestResponse.cpp b/src/solver/BestResponse.cpp index 0c46b36..c5d4ac3 100644 --- a/src/solver/BestResponse.cpp +++ b/src/solver/BestResponse.cpp @@ -3,9 +3,9 @@ // #include "include/solver/BestResponse.h" -#include -#include -#include +// #include +// #include +// #include //#define DEBUG; BestResponse::BestResponse(vector> &private_combos, int player_number, @@ -40,7 +40,8 @@ float BestResponse::printExploitability(shared_ptr root, int itera if(this->reach_probs.empty()) this->reach_probs = vector> (this->player_number); - qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Iter: %s").toStdString().c_str(),iterationCount)); + // qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Iter: %s").toStdString().c_str(),iterationCount)); + logger->log("Iter: %d", iterationCount); float exploitible = 0; // 构造双方初始reach probs(按照手牌weights) for (int player_id = 0; player_id < this->player_number; player_id++) { @@ -54,10 +55,12 @@ float BestResponse::printExploitability(shared_ptr root, int itera for (int player_id = 0; player_id < this->player_number; player_id++) { float player_exploitability = getBestReponseEv(root, player_id, reach_probs, initialBoard, 0); exploitible += player_exploitability; - qDebug().noquote() << (QString::fromStdString(tfm::format(QObject::tr("player %s exploitability %s").toStdString().c_str(), player_id, player_exploitability))); + // qDebug().noquote() << (QString::fromStdString(tfm::format(QObject::tr("player %s exploitability %s").toStdString().c_str(), player_id, player_exploitability))); + logger->log("player %d exploitability %f", player_id, player_exploitability); } float total_exploitability = exploitible / this->player_number / initial_pot * 100; - qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Total exploitability %s precent").toStdString().c_str(), total_exploitability)); + // qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Total exploitability %s precent").toStdString().c_str(), total_exploitability)); + logger->log("Total exploitability %f precent", total_exploitability); return total_exploitability; } diff --git a/src/solver/PCfrSolver.cpp b/src/solver/PCfrSolver.cpp index d73aa53..a406a36 100644 --- a/src/solver/PCfrSolver.cpp +++ b/src/solver/PCfrSolver.cpp @@ -4,9 +4,9 @@ #include #include "include/solver/PCfrSolver.h" -#include -#include -#include +// #include +// #include +// #include //#define DEBUG; @@ -16,7 +16,7 @@ PCfrSolver::~PCfrSolver(){ PCfrSolver::PCfrSolver(shared_ptr tree, vector range1, vector range2, vector initial_board, shared_ptr compairer, Deck deck, int iteration_number, bool debug, - int print_interval, string logfile, string trainer, Solver::MonteCarolAlg monteCarolAlg,int warmup,float accuracy,bool use_isomorphism,int use_halffloats,int num_threads) :Solver(tree){ + int print_interval, /*string logfile*/Logger *logger, string trainer, Solver::MonteCarolAlg monteCarolAlg,int warmup,float accuracy,bool use_isomorphism,int use_halffloats,int num_threads):Solver(tree, logger){ this->initial_board = initial_board; this->initial_board_long = Card::boardInts2long(initial_board); this->logfile = logfile; @@ -53,7 +53,8 @@ PCfrSolver::PCfrSolver(shared_ptr tree, vector range1, v if(num_threads == -1){ num_threads = omp_get_num_procs(); } - qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Using %s threads").toStdString().c_str(),num_threads)); + // qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Using %s threads").toStdString().c_str(),num_threads)); + logger->log("Using %d threads", num_threads); this->num_threads = num_threads; this->distributing_task = false; omp_set_num_threads(this->num_threads); @@ -776,7 +777,7 @@ void PCfrSolver::train() { } BestResponse br = BestResponse(player_privates,this->player_number,this->pcm,this->rrm,this->deck,this->debug,this->color_iso_offset,this->split_round,this->num_threads,this->use_halffloats); - + br.logger = logger; br.printExploitability(tree->getRoot(), 0, tree->getRoot()->getPot(), initial_board_long); vector> reach_probs = this->getReachProbs(); @@ -802,9 +803,11 @@ void PCfrSolver::train() { if( (i % this->print_interval == 0 && i != 0 && i >= this->warmup) || this->nowstop) { endtime = timeSinceEpochMillisec(); long time_ms = endtime - begintime; - qDebug().noquote() << "-------------------"; + // qDebug().noquote() << "-------------------"; + logger->log("-------------------"); float expliotibility = br.printExploitability(tree->getRoot(), i + 1, tree->getRoot()->getPot(), initial_board_long); - qDebug().noquote() << QObject::tr("time used: ") << float(time_ms) / 1000 << QObject::tr(" second."); + // qDebug().noquote() << QObject::tr("time used: ") << float(time_ms) / 1000 << QObject::tr(" second."); + logger->log("time used: %f second.", float(time_ms) / 1000); if(!this->logfile.empty()){ json jo; jo["iteration"] = i; @@ -823,7 +826,8 @@ void PCfrSolver::train() { } } - qDebug().noquote() << QObject::tr("collecting statics"); + // qDebug().noquote() << QObject::tr("collecting statics"); + logger->log("collecting statics"); this->collecting_statics = true; for(int player_id = 0;player_id < this->player_number;player_id ++) { this->round_deal = vector{-1,-1,-1,-1}; @@ -838,8 +842,8 @@ void PCfrSolver::train() { } this->collecting_statics = false; this->statics_collected = true; - qDebug().noquote() << QObject::tr("statics collected"); - + // qDebug().noquote() << QObject::tr("statics collected"); + logger->log("statics collected"); if(!this->logfile.empty()) { fileWriter.flush(); fileWriter.close(); @@ -935,14 +939,14 @@ void PCfrSolver::reConvertJson(const shared_ptr& node,json& strate shared_ptr childerns = chanceNode->getChildren(); vector card_strs; for(Card card:cards) - card_strs.push_back(card.toString()); + card_strs.push_back(card.getCard()); json& dealcards = (*retval)["dealcards"]; for(std::size_t i = 0;i < cards.size();i ++){ vector> new_exchange_color_list(exchange_color_list); Card& one_card = const_cast(cards[i]); vector new_prefix(prefix); - new_prefix.push_back("Chance:" + one_card.toString()); + new_prefix.push_back("Chance:" + one_card.getCard()); std::size_t card = i; @@ -984,7 +988,7 @@ void PCfrSolver::reConvertJson(const shared_ptr& node,json& strate throw runtime_error("exchange color list shouldn't be exceed size 1 here"); } - string one_card_str = one_card.toString(); + string one_card_str = one_card.getCard(); if(exchange_color_list.size() == 1) { int rank1 = exchange_color_list[0][0]; int rank2 = exchange_color_list[0][1]; diff --git a/src/solver/Solver.cpp b/src/solver/Solver.cpp index 18424ad..5fd4677 100644 --- a/src/solver/Solver.cpp +++ b/src/solver/Solver.cpp @@ -8,7 +8,7 @@ Solver::Solver() { } -Solver::Solver(shared_ptr tree) { +Solver::Solver(shared_ptr tree, Logger *logger):logger(logger) { this->tree = tree; } diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp index 0f2aa9e..7c43e91 100644 --- a/src/solver/slice_cfr.cpp +++ b/src/solver/slice_cfr.cpp @@ -287,8 +287,8 @@ size_t SliceCFR::init_leaf_node() { } } sd_offset = leaf_node_dfs[FOLD_TYPE].size(); - printf("%zd,%zd\n", pre_leaf_node[P0].size(), pre_leaf_node[P1].size()); - printf("%d,%d,%zd,%zd\n", n_leaf_node, node_idx, root_child_idx[P0].size(), root_child_idx[P1].size()); + logger->log("%zd,%zd", pre_leaf_node[P0].size(), pre_leaf_node[P1].size()); + logger->log("%d,%d,%zd,%zd", n_leaf_node, node_idx, root_child_idx[P0].size(), root_child_idx[P1].size()); size_t max_val[N_PLAYER] = {0, 0}, min_val[N_PLAYER] = {INT_MAX, INT_MAX}; for(int i = 0; i < N_PLAYER; i++) { @@ -302,7 +302,7 @@ size_t SliceCFR::init_leaf_node() { min_val[i] = min(min_val[i], node.leaf_node_idx.size()); } } - printf("%zd,%zd,%zd,%zd\n", min_val[P0], max_val[P0], min_val[P1], max_val[P1]); + logger->log("%zd,%zd,%zd,%zd", min_val[P0], max_val[P0], min_val[P1], max_val[P1]); ev[FOLD_TYPE].insert(ev[FOLD_TYPE].end(), ev[SHOWDOWN_TYPE].begin(), ev[SHOWDOWN_TYPE].end()); ev[SHOWDOWN_TYPE].clear(); @@ -323,8 +323,9 @@ SliceCFR::SliceCFR( int train_step, int print_interval, float accuracy, - int n_thread -):tree(tree), deck(deck), steps(train_step), interval(print_interval), n_thread(max(0,n_thread)), rrm(compairer) { + int n_thread, + Logger *logger +):deck(deck), steps(train_step), interval(print_interval), n_thread(max(0,n_thread)), rrm(compairer), Solver(tree, logger) { init_board = Card::boardInts2long(initial_board); init_round = GameTreeNode::gameRound2int(tree->getRoot()->getRound()); if(init_round < FLOP_ROUND) return; @@ -343,7 +344,7 @@ SliceCFR::SliceCFR( void SliceCFR::init() { float unit = 1 << 20; size_t size = estimate_tree_size(); - printf("estimate memory:%f MB\n", size/unit); + logger->log("estimate memory:%f MB", size/unit); leaf_node_dfs.resize(N_LEAF_TYPE); ev.resize(N_LEAF_TYPE); @@ -367,7 +368,7 @@ void SliceCFR::init() { if(dfs_idx == 0 || dfs_node[0].n_act == 0) return; size = init_memory(); - printf("%d nodes, total:%f MB\n", dfs_idx, size/unit); + logger->log("%d nodes, total:%f MB", dfs_idx, size/unit); init_succ = true; } @@ -743,7 +744,7 @@ void SliceCFR::train() { // _rm(P1, false); // _reach_prob(P0, false); vector res = exploitability(); - printf("0:%f %f %f\n", res[0], res[1], (res[0]+res[1])/2); + logger->log("0:%f %f %f", res[0], res[1], (res[0]+res[1])/2); // 计算exploitability后,双方的rm和p0的reach_prob已经恢复 pos_coef = neg_coef = coef = 0; double temp = 0; @@ -763,9 +764,9 @@ void SliceCFR::train() { cnt = 0; res = exploitability(); total = timeSinceEpochMillisec() - start; - printf("%d:%.3f,%.3fs\n", iter, timer.ms()/1000.0, total/1000.0); + logger->log("%d:%.3f,%.3fs", iter, timer.ms()/1000.0, total/1000.0); temp = (res[0] + res[1]) / 2; - printf("%d:%f %f %f\n", iter, res[0], res[1], temp); + logger->log("%d:%f %f %f", iter, res[0], res[1], temp); if(temp <= tol) break; } if(stop_flag) break; @@ -773,8 +774,8 @@ void SliceCFR::train() { if(cnt) { res = exploitability(); total = timeSinceEpochMillisec() - start; - printf("%d:%.3f,%.3fs\n", iter, timer.ms()/1000.0, total/1000.0); - printf("%d:%f %f %f\n", iter, res[0], res[1], (res[0]+res[1])/2); + logger->log("%d:%.3f,%.3fs", iter, timer.ms()/1000.0, total/1000.0); + logger->log("%d:%f %f %f", iter, res[0], res[1], (res[0]+res[1])/2); } } diff --git a/src/tools/CommandLineTool.cpp b/src/tools/CommandLineTool.cpp index a0cd885..31a3994 100644 --- a/src/tools/CommandLineTool.cpp +++ b/src/tools/CommandLineTool.cpp @@ -2,39 +2,24 @@ // Created by bytedance on 7.6.21. // #include "include/tools/CommandLineTool.h" -#include - -CommandLineTool::CommandLineTool(string mode,string resource_dir) { - string suits = "c,d,h,s"; - string ranks; - this->resource_dir = resource_dir; - string compairer_file,compairer_file_bin; - int lines; - if(mode == "holdem"){ - ranks = "2,3,4,5,6,7,8,9,T,J,Q,K,A"; - compairer_file = this->resource_dir + "/compairer/card5_dic_sorted.txt"; - compairer_file_bin = this->resource_dir + "/compairer/card5_dic_zipped.bin"; - lines = 2598961; - }else if(mode == "shortdeck"){ - ranks = "6,7,8,9,T,J,Q,K,A"; - compairer_file = this->resource_dir + "/compairer/card5_dic_sorted_shortdeck.txt"; - compairer_file_bin = this->resource_dir + "/compairer/card5_dic_zipped_shortdeck.bin"; - lines = 376993; - }else{ - throw runtime_error(tfm::format("mode not recognized : ",mode)); - } - string logfile_name = "../resources/outputs/outputs_log.txt"; - this->ps = PokerSolver(ranks,suits,compairer_file,lines,compairer_file_bin); +// #include +#include +#include +#include + +CommandLineTool::CommandLineTool() { + // string logfile_name = "../resources/outputs/outputs_log.txt"; + // this->ps = PokerSolver(mode, resource_dir); - StreetSetting gbs_flop_ip = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_turn_ip = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_river_ip = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_flop_ip = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_turn_ip = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_river_ip = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_flop_oop = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_turn_oop = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_river_oop = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_flop_oop = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_turn_oop = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_river_oop = StreetSetting(vector{},vector{},vector{},true); - this->gtbs = make_shared(gbs_flop_ip,gbs_turn_ip,gbs_river_ip,gbs_flop_oop,gbs_turn_oop,gbs_river_oop); + // this->gtbs = make_shared(gbs_flop_ip,gbs_turn_ip,gbs_river_ip,gbs_flop_oop,gbs_turn_oop,gbs_river_oop); //ps.build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,*gtbs.get(),allin_threshold); //cout << "build tree finished" << endl; /* @@ -56,41 +41,127 @@ CommandLineTool::CommandLineTool(string mode,string resource_dir) { */ } -void CommandLineTool::startWorking() { +void CommandLineTool::startWorking(PokerSolver *ps) { string input_line; while(cin) { getline(cin, input_line); - this->processCommand(input_line); + this->processCommand(input_line, ps); }; } -void CommandLineTool::execFromFile(string input_file){ +void CommandLineTool::execFromFile(const char *input_file, PokerSolver *ps) { std::ifstream infile(input_file); std::string input_line; while (std::getline(infile, input_line)) { - this->processCommand(input_line); + this->processCommand(input_line, ps); } } -void split(const string& s, char c, - vector& v) { - string::size_type i = 0; - string::size_type j = s.find(c); - - while (j != string::npos) { +void split(const string& s, char delimiter, vector& v) { + size_t i = s.find_first_not_of(delimiter), j = 0; + while (i != string::npos) { + j = s.find_first_of(delimiter, i+1); + if(j == string::npos) j = s.size(); v.push_back(s.substr(i, j-i)); - i = ++j; - j = s.find(c, j); + i = s.find_first_not_of(delimiter, j+1); + } +} - if (j == string::npos) - v.push_back(s.substr(i, s.length())); +template +string tostring(T val) { + string s = to_string(val); + for(size_t i = s.size() - 1; i > 0; i--) { + if(s[i] == '0') s.pop_back(); + else if(s[i] == '.') { + s.pop_back(); + break; + } + else break; } + return s; +} + +template +string tostring_oss(T val) { + ostringstream oss; + oss << val; + return oss.str(); } +void join(const vector &vec, char delimiter, string &out) { + size_t n = vec.size(); + if(n) out += tostring(vec[0]); + for(int i = 1; i < n; i++) { + out += delimiter; + out += tostring(vec[i]); + } +} -void CommandLineTool::processCommand(string input) { +bool CommandLineTool::set_board(string &str) { + board = str; + vector board_str_arr = string_split(board,','); + if(board_str_arr.size() == 3){ + this->current_round = 1; + }else if(board_str_arr.size() == 4){ + this->current_round = 2; + }else if(board_str_arr.size() == 5){ + this->current_round = 3; + }else{ + // throw runtime_error(tfm::format("board %s not recognized",this->board)); + return false; + } + return true; +} + +bool CommandLineTool::set_bet_sizes(string &str, char delimiter, vector *sizes) { + vector params; + split(str, delimiter, params); + int start = (sizes != nullptr ? 0 : 3); + if(params.size() < start) { + // throw runtime_error("param number error"); + return false; + } + if(sizes == nullptr) { + // oop,turn,bet,30,70,100 + StreetSetting& streetSetting = gtbs.get_setting(params[0], params[1]); + string &bet_type = params[2]; + if(bet_type == "allin") { + if(params.size() == start) streetSetting.allin = true; + else streetSetting.allin = stoi(params[start]); + } + else if(bet_type == "bet") sizes = &(streetSetting.bet_sizes); + else if(bet_type == "raise") sizes = &(streetSetting.raise_sizes); + else if(bet_type == "donk") sizes = &(streetSetting.donk_sizes); + else return false; + } + if(sizes != nullptr) { + sizes->clear(); + std::unordered_set seen; + for(std::size_t i = start; i < params.size(); i++) { + float val = stof(params[i]); + if(seen.count(val)) continue; + sizes->push_back(val); + seen.insert(val); + } + std::sort(sizes->begin(), sizes->end()); + } + return true; +} + +// void show_bet_sizes(std::ofstream &out, const char *player, const char *round, const char *type, vector &sizes) { +// string s; +// join(sizes, ',', s); +// out << "set_bet_sizes " << player << ',' << round << ',' << type; +// if(s.size()) out << ',' << s; +// out << endl; +// } +// void show_bet_sizes(std::ofstream &out, const char *player, const char *round, const char *type, bool allin) { +// out << "set_bet_sizes " << player << ',' << round << ',' << type << ',' << allin; +// } + +void CommandLineTool::processCommand(string &input, PokerSolver *ps) { vector contents; if(input.empty() || input[0] == '#') return; split(input,' ',contents); @@ -99,48 +170,17 @@ void CommandLineTool::processCommand(string input) { string command = contents[0]; string paramstr = contents.size() == 1 ? "" : contents[1]; if(command == "set_pot"){ - this->ip_commit = stof(paramstr) / 2; - this->oop_commit = stof(paramstr) / 2; + set_pot(stof(paramstr)); }else if(command == "set_effective_stack"){ - this->stack = stof(paramstr) + this->ip_commit; + set_effective_stack(stof(paramstr)); }else if(command == "set_board"){ - this->board = paramstr; - vector board_str_arr = string_split(board,','); - if(board_str_arr.size() == 3){ - this->current_round = 1; - }else if(board_str_arr.size() == 4){ - this->current_round = 2; - }else if(board_str_arr.size() == 5){ - this->current_round = 3; - }else{ - throw runtime_error(tfm::format("board %s not recognized",this->board)); - } + set_board(paramstr); }else if(command == "set_range_ip"){ this->range_ip = paramstr; }else if(command == "set_range_oop"){ this->range_oop = paramstr; }else if(command == "set_bet_sizes"){ - vector params; - split(paramstr,',',params); - if(params.size() < 3)throw runtime_error("param number error"); - // oop,turn,bet,30,70,100 - string player = params[0]; - string round = params[1]; - string bet_type = params[2]; - StreetSetting& streetSetting = this->gtbs->get_setting(player,round); - vector* sizes; - if(bet_type == "allin") streetSetting.allin = true; - else if(bet_type == "bet") sizes = &(streetSetting.bet_sizes); - else if(bet_type == "raise") sizes = &(streetSetting.raise_sizes); - else if(bet_type == "donk") sizes = &(streetSetting.donk_sizes); - else throw runtime_error(""); - - if(bet_type == "bet" || bet_type == "raise" || bet_type == "donk"){ - sizes->clear(); - for(std::size_t i = 3;i < params.size();i ++ ){ - sizes->push_back(stof(params[i])); - } - } + set_bet_sizes(paramstr); }else if(command == "set_raise_limit"){ this->raise_limit = stoi(paramstr); }else if(command == "set_accuracy"){ @@ -148,9 +188,9 @@ void CommandLineTool::processCommand(string input) { }else if(command == "set_allin_threshold"){ this->allin_threshold = stof(paramstr); }else if(command == "set_thread_num"){ - this->thread_number = stoi(paramstr); + this->thread_num = stoi(paramstr); }else if(command == "build_tree"){ - this->ps.build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,*gtbs.get(),allin_threshold); + build_tree(ps); }else if(command == "set_max_iteration"){ this->max_iteration = stoi(paramstr); }else if(command == "set_use_isomorphism"){ @@ -158,43 +198,103 @@ void CommandLineTool::processCommand(string input) { }else if(command == "set_print_interval"){ this->print_interval = stoi(paramstr); }else if(command == "start_solve"){ - cout << "<<>>" << endl; - this->ps.train( - this->range_ip, - this->range_oop, - this->board, - "tmp_log.txt", - max_iteration, - this->print_interval, - "discounted_cfr", - -1, - this->accuracy, - this->use_isomorphism, - 0, // TODO: enable half float option for command line tool - this->thread_number, - slice_cfr - ); + start_solve(ps); + }else if(command == "dump_setting"){ + dump_setting(paramstr.c_str()); }else if(command == "dump_result"){ - string output_file = paramstr; - this->ps.dump_strategy(QString::fromStdString(output_file),this->dump_rounds); + res_file = paramstr; + if(!ps) return; + ps->dump_strategy(res_file, this->dump_rounds); }else if(command == "set_dump_rounds"){ this->dump_rounds = stoi(paramstr); }else if(command == "estimate_tree_memory"){ + if(!ps) return; if(range_ip.empty() || range_oop.empty() || board.empty()) { - cout << "Please set range_ip, range_oop and board first." << endl; + // cout << "Please set range_ip, range_oop and board first." << endl; + logger->log("Please set range_ip, range_oop and board first."); return; } - shared_ptr game_tree = ps.get_game_tree(); + shared_ptr game_tree = ps->get_game_tree(); if(game_tree == nullptr) { - cout << "Please buld tree first." << endl; + // cout << "Please buld tree first." << endl; + logger->log("Please buld tree first."); return; } - long long size = ps.estimate_tree_memory(range_ip, range_oop, board); + long long size = ps->estimate_tree_memory(range_ip, range_oop, board); size *= sizeof(float); - cout << (float)size / (1024*1024) << " MB" << endl; + // cout << (float)size / (1024*1024) << " MB" << endl; + logger->log("estimate_tree_memory: %f MB", (float)size / (1024*1024)); }else if(command == "set_slice_cfr"){ slice_cfr = stoi(paramstr); }else{ - cout << "command not recognized: " << command << endl; + // cout << "command not recognized: " << command << endl; + logger->log("command not recognized: %s", command.c_str()); } } + +void CommandLineTool::dump_setting(const char *file) { + static vector player {"oop","ip"}; + static vector round {"flop","turn","river"}; + static vector type {"bet","raise","donk","allin"}; + std::ofstream out(file); + out << "set_pot " << get_pot() << endl; + out << "set_effective_stack " << get_effective_stack() << endl; + out << "set_board " << board << endl; + out << "set_range_oop " << range_oop << endl; + out << "set_range_ip " << range_ip << endl; + + for(size_t i = 0; i < player.size(); i++) { + for(size_t j = 0; j < round.size(); j++) { + for(size_t k = 0; k < type.size(); k++) { + if(k == 2 && (i == 1 || j == 0)) continue;// no donk:ip, oop flop + out << "set_bet_sizes " << player[i] << ',' << round[j] << ',' << type[k] << ','; + StreetSetting& st = gtbs.get_setting(player[i], round[j]); + if(k == 3) out << st.allin; + else { + vector &vec = (k == 0 ? st.bet_sizes : (k == 1 ? st.raise_sizes : st.donk_sizes)); + string str; + join(vec, ',', str); + out << str; + } + out << endl; + } + } + } + out << "set_allin_threshold " << allin_threshold << endl; + out << "set_raise_limit " << raise_limit << endl; + out << "build_tree" << endl; + out << "set_thread_num " << thread_num << endl; + out << "set_accuracy " << accuracy << endl; + out << "set_max_iteration " << max_iteration << endl; + out << "set_print_interval " << print_interval << endl; + out << "set_use_isomorphism " << use_isomorphism << endl; + out << "set_slice_cfr " << slice_cfr << endl; + out << "start_solve" << endl; + out << "set_dump_rounds " << dump_rounds << endl; + out << "dump_result " << res_file << endl; + out.close(); +} + +int cmd_api(string &input_file, string &resource_dir, string &mode, string &log_file) { + if(resource_dir.empty()){ + resource_dir = "./resources"; + } + if(log_file.empty()) log_file = get_localtime() + ".txt"; + Logger logger(true, log_file.c_str(), "w", true, true, 1); + PokerMode poker_mode = PokerMode::UNKNOWN; + if(mode.empty() || mode == "holdem") poker_mode = PokerMode::HOLDEM; + else if(mode == "shortdeck") poker_mode = PokerMode::SHORTDECK; + else throw runtime_error(tfm::format("mode %s error, not in ['holdem','shortdeck']", mode)); + PokerSolver ps = PokerSolver(poker_mode, resource_dir); + CommandLineTool clt; + clt.logger = &logger; + ps.logger = &logger; + if(input_file.empty()) { + clt.startWorking(&ps); + }else{ + // cout << "EXEC FROM FILE" << endl; + logger.log("EXEC FROM FILE"); + clt.execFromFile(input_file.c_str(), &ps); + } + return 0; +} \ No newline at end of file diff --git a/src/tools/GameTreeBuildingSettings.cpp b/src/tools/GameTreeBuildingSettings.cpp index a815a4c..422c3ed 100644 --- a/src/tools/GameTreeBuildingSettings.cpp +++ b/src/tools/GameTreeBuildingSettings.cpp @@ -13,7 +13,7 @@ GameTreeBuildingSettings::GameTreeBuildingSettings( StreetSetting river_oop):flop_ip(flop_ip),turn_ip(turn_ip),river_ip(river_ip),flop_oop(flop_oop),turn_oop(turn_oop),river_oop(river_oop) { } -StreetSetting& GameTreeBuildingSettings::get_setting(string player,string round){ +StreetSetting& GameTreeBuildingSettings::get_setting(string &player, string &round){ if(player == "ip" && round == "flop") return flop_ip; else if(player == "ip" && round == "turn") return turn_ip; else if(player == "ip" && round == "river") return river_ip; diff --git a/src/tools/logger.cpp b/src/tools/logger.cpp new file mode 100644 index 0000000..60bbca8 --- /dev/null +++ b/src/tools/logger.cpp @@ -0,0 +1,47 @@ +#include "include/tools/logger.h" +#include + +void get_localtime(char *buf, size_t n, const char *format) { + using namespace std::chrono; + system_clock::time_point tp = system_clock::now(); + time_t now = system_clock::to_time_t(tp); + // time(&now); + int ms = duration_cast(tp.time_since_epoch()).count() - now * 1000; + tm tm_now; + localtime_s(&tm_now, &now); + // strftime(buf, n, format, &tm_now); + sprintf_s(buf, n, format, tm_now.tm_year+1900, tm_now.tm_mon+1, tm_now.tm_mday, + tm_now.tm_hour, tm_now.tm_min, tm_now.tm_sec, ms); +} + +string get_localtime() { + char buf[25]; + get_localtime(buf, sizeof(buf), "%d_%02d_%02d_%02d_%02d_%02d.%03d"); + return string(buf); +} + +void Logger::log(const char *format, ...) { + if(timestamp) log_time(); + va_list args = nullptr; + va_start(args, format); + if(file) { + vfprintf(file, format, args); + if((++step) == period) { + step = 0; + fflush(file); + } + if(new_line) fprintf(file, "\n"); + } + if(cmd) { + vprintf(format, args); + if(new_line) printf("\n"); + } + va_end(args); +} + +void Logger::log_time() { + char buf[28]; + get_localtime(buf, sizeof(buf), "%d-%02d-%02d %02d:%02d:%02d.%03d "); + if(file) fprintf(file, buf); + if(cmd) printf(buf); +} \ No newline at end of file diff --git a/src/ui/boardselectortabledelegate.cpp b/src/ui/boardselectortabledelegate.cpp index a2b0afe..a281058 100644 --- a/src/ui/boardselectortabledelegate.cpp +++ b/src/ui/boardselectortabledelegate.cpp @@ -28,7 +28,8 @@ void BoardSelectorTableDelegate::paint(QPainter *painter, const QStyleOptionView painter->fillRect(rect, brush); QTextDocument doc; - doc.setHtml(Card(options.text.toStdString()).toFormattedHtml()); + Card card(options.text.toStdString()); + doc.setHtml(toFormattedHtml(card)); painter->translate(options.rect.left(), options.rect.top()); QRect clip(0, 0, options.rect.width(), options.rect.height()); diff --git a/src/ui/detailitemdelegate.cpp b/src/ui/detailitemdelegate.cpp index 4d6040b..d79cedd 100644 --- a/src/ui/detailitemdelegate.cpp +++ b/src/ui/detailitemdelegate.cpp @@ -114,8 +114,8 @@ void DetailItemDelegate::paint_strategy(QPainter *painter, const QStyleOptionVie } } options.text = ""; - options.text += detailViewerModel->tableStrategyModel->cardint2card[card1].toFormattedHtml(); - options.text += detailViewerModel->tableStrategyModel->cardint2card[card2].toFormattedHtml(); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card1]); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card2]); options.text = "

" + options.text + "<\/h2>"; for(std::size_t i = 0;i < strategy.size();i ++){ GameActions one_action = gameActions[i]; @@ -189,8 +189,8 @@ void DetailItemDelegate::paint_range(QPainter *painter, const QStyleOptionViewIt painter->fillRect(rect, brush); options.text = ""; - options.text += detailViewerModel->tableStrategyModel->cardint2card[cord.first].toFormattedHtml(); - options.text += detailViewerModel->tableStrategyModel->cardint2card[cord.second].toFormattedHtml(); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[cord.first]); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[cord.second]); options.text = "

" + options.text + "<\/h2>"; options.text += QString("

%1<\/h2>").arg(QString::number(range_number,'f',3)); @@ -318,8 +318,8 @@ void DetailItemDelegate::paint_evs(QPainter *painter, const QStyleOptionViewItem } options.text = ""; - options.text += detailViewerModel->tableStrategyModel->cardint2card[card1].toFormattedHtml(); - options.text += detailViewerModel->tableStrategyModel->cardint2card[card2].toFormattedHtml(); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card1]); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card2]); options.text = "

" + options.text + "<\/h2>"; for(std::size_t i = 0;i < evs.size();i ++){ GameActions one_action = gameActions[i]; @@ -378,7 +378,7 @@ void DetailItemDelegate::paint_evs_only(QPainter *painter, const QStyleOptionVie if(ind < evs.size() and ind < strategy_number) { float one_ev = evs[ind]; - float normalized_ev = normalization_tanh(detailViewerModel->tableStrategyModel->get_solver()->stack,one_ev); + float normalized_ev = normalization_tanh(detailViewerModel->tableStrategyModel->get_solver()->clt->stack,one_ev); //options.text += QString("
%1").arg(QString::number(normalized_ev)); pair strategy_ui_table = detailViewerModel->tableStrategyModel->ui_strategy_table[this->detailWindowSetting->grid_i][this->detailWindowSetting->grid_j][ind]; @@ -396,8 +396,8 @@ void DetailItemDelegate::paint_evs_only(QPainter *painter, const QStyleOptionVie painter->fillRect(rect, brush); options.text = ""; - options.text += detailViewerModel->tableStrategyModel->cardint2card[card1].toFormattedHtml(); - options.text += detailViewerModel->tableStrategyModel->cardint2card[card2].toFormattedHtml(); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card1]); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card2]); options.text = "

" + options.text + "<\/h2>"; options.text += QString("

%1<\/h2>").arg(QString::number(one_ev,'f',3)); diff --git a/src/ui/strategyitemdelegate.cpp b/src/ui/strategyitemdelegate.cpp index b2adb6b..7998324 100644 --- a/src/ui/strategyitemdelegate.cpp +++ b/src/ui/strategyitemdelegate.cpp @@ -198,7 +198,7 @@ void StrategyItemDelegate::paint_evs(QPainter *painter, const QStyleOptionViewIt int last_left = 0; for(std::size_t i = 0;i < evs.size();i ++ ){ float one_ev = evs[evs.size() - i - 1]; - float normalized_ev = normalization_tanh(this->qSolverJob->stack,one_ev); + float normalized_ev = normalization_tanh(this->qSolverJob->clt->stack,one_ev); //options.text += QString("
%1").arg(QString::number(normalized_ev)); int red = max((int)(255 - normalized_ev * 255),0); diff --git a/src/ui/treemodel.cpp b/src/ui/treemodel.cpp index 8c94f6c..0835323 100644 --- a/src/ui/treemodel.cpp +++ b/src/ui/treemodel.cpp @@ -120,9 +120,9 @@ void TreeModel::reGenerateTreeItem(GameTreeNode::GameRound round,TreeItem* node_ void TreeModel::setupModelData() { PokerSolver * solver; - if(this->qSolverJob->mode == QSolverJob::Mode::HOLDEM){ + if(this->qSolverJob->mode == PokerMode::HOLDEM){ solver = &(this->qSolverJob->ps_holdem); - }else if(this->qSolverJob->mode == QSolverJob::Mode::SHORTDECK){ + }else if(this->qSolverJob->mode == PokerMode::SHORTDECK){ solver = &(this->qSolverJob->ps_shortdeck); }else{ throw runtime_error("holdem mode incorrect"); diff --git a/strategyexplorer.cpp b/strategyexplorer.cpp index 0f47162..ebfbbd1 100644 --- a/strategyexplorer.cpp +++ b/strategyexplorer.cpp @@ -45,10 +45,10 @@ StrategyExplorer::StrategyExplorer(QWidget *parent,QSolverJob * qSolverJob) : Deck* deck = this->qSolverJob->get_solver()->get_deck(); int index = 0; - QString board_qstring = QString::fromStdString(this->qSolverJob->board); + QString board_qstring = QString::fromStdString(this->qSolverJob->clt->board); for(Card one_card: deck->getCards()){ - if(board_qstring.contains(QString::fromStdString(one_card.toString())))continue; - QString card_str_formatted = QString::fromStdString(one_card.toFormattedString()); + if(board_qstring.contains(QString::fromStdString(one_card.getCard())))continue; + QString card_str_formatted = QString::fromStdString(toFormattedString(one_card)); this->ui->turnCardBox->addItem(card_str_formatted); this->ui->riverCardBox->addItem(card_str_formatted); @@ -120,7 +120,7 @@ void StrategyExplorer::item_expanded(const QModelIndex& index){ } void StrategyExplorer::process_board(TreeItem* treeitem){ - vector board_str_arr = string_split(this->qSolverJob->board,','); + vector board_str_arr = string_split(this->qSolverJob->clt->board,','); vector cards; for(string one_board_str:board_str_arr){ cards.push_back(Card(one_board_str)); @@ -136,7 +136,7 @@ void StrategyExplorer::process_board(TreeItem* treeitem){ cards.push_back(Card(this->tableStrategyModel->getRiverCard())); } } - this->ui->boardLabel->setText(QString("%1: ").arg(tr("board")) + Card::boardCards2html(cards)); + this->ui->boardLabel->setText(QString("%1: ").arg(tr("board")) + boardCards2html(cards)); } void StrategyExplorer::process_treeclick(TreeItem* treeitem){ From a1aab796ad2e8d69fbee718fe23bf7171375245b Mon Sep 17 00:00:00 2001 From: yffbit Date: Wed, 29 May 2024 15:45:06 +0800 Subject: [PATCH 12/19] update --- include/tools/logger.h | 5 ++--- mainwindow.h | 6 +++++- src/tools/logger.cpp | 14 ++++++++++++-- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/include/tools/logger.h b/include/tools/logger.h index ad17ac9..8ecde8e 100644 --- a/include/tools/logger.h +++ b/include/tools/logger.h @@ -16,9 +16,8 @@ class Logger { Logger(bool cmd, const char *path, const char *mode = "w", bool timestamp = false, bool new_line = true, int period = 10) :cmd(cmd), timestamp(timestamp), new_line(new_line), period(period) { if(path) { - errno_t err = fopen_s(&file, path, mode); - if(err) printf("%d\n", err); - if(!file) printf("create file %s failed\n", path); + file = fopen(path, mode); + if(!file) printf("failed to create file %s\n", path); } } virtual ~Logger() { diff --git a/mainwindow.h b/mainwindow.h index 8efce20..ca902d0 100644 --- a/mainwindow.h +++ b/mainwindow.h @@ -20,7 +20,7 @@ class QLogger : public Logger { QLogger(const char *path, const char *mode = "w", bool timestamp = false, int period = 10):Logger(false, path, mode, timestamp, true, period) {} virtual void log(const char *format, ...) { if(timestamp) log_time(); - va_list args = nullptr; + va_list args; va_start(args, format); if(file) { vfprintf(file, format, args); @@ -29,6 +29,10 @@ class QLogger : public Logger { fflush(file); } if(new_line) fprintf(file, "\n"); +#ifdef __GNUC__ + va_end(args); + va_start(args, format); +#endif } // qDebug().noquote() << QString::vasprintf(QObject::tr(format).toLocal8Bit(), args); qDebug().noquote() << QString::vasprintf(QObject::tr(format).toStdString().c_str(), args); diff --git a/src/tools/logger.cpp b/src/tools/logger.cpp index 60bbca8..ebc534c 100644 --- a/src/tools/logger.cpp +++ b/src/tools/logger.cpp @@ -8,9 +8,13 @@ void get_localtime(char *buf, size_t n, const char *format) { // time(&now); int ms = duration_cast(tp.time_since_epoch()).count() - now * 1000; tm tm_now; +#ifdef _MSC_VER localtime_s(&tm_now, &now); +#else + localtime_r(&now, &tm_now); +#endif // strftime(buf, n, format, &tm_now); - sprintf_s(buf, n, format, tm_now.tm_year+1900, tm_now.tm_mon+1, tm_now.tm_mday, + snprintf(buf, n, format, tm_now.tm_year+1900, tm_now.tm_mon+1, tm_now.tm_mday, tm_now.tm_hour, tm_now.tm_min, tm_now.tm_sec, ms); } @@ -22,7 +26,7 @@ string get_localtime() { void Logger::log(const char *format, ...) { if(timestamp) log_time(); - va_list args = nullptr; + va_list args; va_start(args, format); if(file) { vfprintf(file, format, args); @@ -31,6 +35,12 @@ void Logger::log(const char *format, ...) { fflush(file); } if(new_line) fprintf(file, "\n"); +#ifdef __GNUC__ + if(cmd) { + va_end(args); + va_start(args, format); + } +#endif } if(cmd) { vprintf(format, args); From 4e44daf2152167f3ff173fc622a374aa02ca2e59 Mon Sep 17 00:00:00 2001 From: yffbit Date: Wed, 29 May 2024 15:58:19 +0800 Subject: [PATCH 13/19] update CMakeLists.txt --- CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 351234f..72b1d7c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,8 +14,10 @@ set(CMAKE_CXX_STANDARD 20) # set(CMAKE_CXX_STANDARD_REQUIRED ON) if(MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") - message("CMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") endif() +message("CMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}") set(CMAKE_INCLUDE_CURRENT_DIR ON) include_directories(include) From c3e367bdb6d4ba098517c8055b62de3bc9744918 Mon Sep 17 00:00:00 2001 From: yffbit Date: Wed, 29 May 2024 23:43:33 +0800 Subject: [PATCH 14/19] update main.yml --- .github/workflows/main.yml | 6 +++--- include/tools/logger.h | 2 +- mainwindow.cpp | 2 +- mainwindow.h | 2 +- src/tools/CommandLineTool.cpp | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d552122..b368f0c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,13 +32,13 @@ jobs: # Runs a set of commands using the runners shell - name: install dependencies run: | - sudo apt install -y qt5-default qtbase5-dev qt5-qmake build-essential wget + sudo apt install -y qt5-default qtbase5-dev qt5-qmake build-essential wget cmake # Runs a set of commands using the runners shell - name: make appimage run: | - ls - # ./build-AppImage.sh + ls + cmake --version cmake -DCMAKE_BUILD_TYPE=Release -S . -B build make -C build -j diff --git a/include/tools/logger.h b/include/tools/logger.h index 8ecde8e..b121be6 100644 --- a/include/tools/logger.h +++ b/include/tools/logger.h @@ -13,7 +13,7 @@ string get_localtime(); class Logger { public: - Logger(bool cmd, const char *path, const char *mode = "w", bool timestamp = false, bool new_line = true, int period = 10) + Logger(bool cmd, const char *path, const char *mode = "w+", bool timestamp = false, bool new_line = true, int period = 10) :cmd(cmd), timestamp(timestamp), new_line(new_line), period(period) { if(path) { file = fopen(path, mode); diff --git a/mainwindow.cpp b/mainwindow.cpp index 07a8ec7..3ed6a4b 100644 --- a/mainwindow.cpp +++ b/mainwindow.cpp @@ -18,7 +18,7 @@ MainWindow::MainWindow(QWidget *parent) : connect(this->ui->actionimport, &QAction::triggered, this, &MainWindow::on_actionimport_triggered); connect(this->ui->actionexport, &QAction::triggered, this, &MainWindow::on_actionexport_triggered); connect(this->ui->actionclear_all, &QAction::triggered, this, &MainWindow::on_actionclear_all_triggered); - logger = new QLogger((get_localtime() + ".txt").c_str(), "w", false, 1); + logger = new QLogger((get_localtime() + ".txt").c_str(), "w+", false, 1); clt.logger = logger; qSolverJob = new QSolverJob; qSolverJob->clt = &clt; diff --git a/mainwindow.h b/mainwindow.h index ca902d0..9eb56a0 100644 --- a/mainwindow.h +++ b/mainwindow.h @@ -17,7 +17,7 @@ class QLogger : public Logger { public: - QLogger(const char *path, const char *mode = "w", bool timestamp = false, int period = 10):Logger(false, path, mode, timestamp, true, period) {} + QLogger(const char *path, const char *mode = "w+", bool timestamp = false, int period = 10):Logger(false, path, mode, timestamp, true, period) {} virtual void log(const char *format, ...) { if(timestamp) log_time(); va_list args; diff --git a/src/tools/CommandLineTool.cpp b/src/tools/CommandLineTool.cpp index 31a3994..f799a8a 100644 --- a/src/tools/CommandLineTool.cpp +++ b/src/tools/CommandLineTool.cpp @@ -280,7 +280,7 @@ int cmd_api(string &input_file, string &resource_dir, string &mode, string &log_ resource_dir = "./resources"; } if(log_file.empty()) log_file = get_localtime() + ".txt"; - Logger logger(true, log_file.c_str(), "w", true, true, 1); + Logger logger(true, log_file.c_str(), "w+", true, true, 1); PokerMode poker_mode = PokerMode::UNKNOWN; if(mode.empty() || mode == "holdem") poker_mode = PokerMode::HOLDEM; else if(mode == "shortdeck") poker_mode = PokerMode::SHORTDECK; From ba31156b8bd04ec6d6712f6b8c09188e79222b0e Mon Sep 17 00:00:00 2001 From: yffbit Date: Wed, 29 May 2024 23:54:15 +0800 Subject: [PATCH 15/19] update main.yml --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b368f0c..fac4d3e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,7 +32,7 @@ jobs: # Runs a set of commands using the runners shell - name: install dependencies run: | - sudo apt install -y qt5-default qtbase5-dev qt5-qmake build-essential wget cmake + sudo apt install -y qt5-default qtbase5-dev qt5-qmake qttools5-dev build-essential wget cmake # Runs a set of commands using the runners shell - name: make appimage From 6742c5638f959b761e725e34d02a53cb4d245bfb Mon Sep 17 00:00:00 2001 From: yffbit Date: Sat, 1 Jun 2024 18:33:17 +0800 Subject: [PATCH 16/19] get_evs --- include/solver/cuda_cfr.h | 6 +- include/solver/cuda_func.h | 2 +- include/solver/slice_cfr.h | 26 +++- include/tools/CommandLineTool.h | 2 +- src/Card.cpp | 2 + src/solver/cuda_cfr.cu | 18 +-- src/solver/cuda_func.cu | 2 +- src/solver/slice_cfr.cpp | 262 +++++++++++++++++++++++++------- 8 files changed, 240 insertions(+), 80 deletions(-) diff --git a/include/solver/cuda_cfr.h b/include/solver/cuda_cfr.h index 41ac019..42a01b7 100644 --- a/include/solver/cuda_cfr.h +++ b/include/solver/cuda_cfr.h @@ -66,14 +66,14 @@ class CudaCFR : public SliceCFR { size_t init_leaf_node(); void set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset); size_t init_strength_table(); - virtual void step(int iter, int player, bool best_cfv=false); + virtual void step(int iter, int player, int task); virtual void leaf_cfv(int player); int block_size(int size) {// ceil return (size + LANE_SIZE - 1) / LANE_SIZE; } void clear_prob_sum(int len); - virtual void _reach_prob(int player, bool best_cfv=false); - virtual void _rm(int player, bool best_cfv=false); + virtual void _reach_prob(int player, bool avg_strategy); + virtual void _rm(int player, bool avg_strategy); virtual void clear_data(int player); virtual void clear_root_cfv(); virtual void post_process(); diff --git a/include/solver/cuda_func.h b/include/solver/cuda_func.h index 1c7cb4c..b4c807d 100644 --- a/include/solver/cuda_func.h +++ b/include/solver/cuda_func.h @@ -18,6 +18,6 @@ extern __global__ void fold_cfv_kernel(int player, int size, CudaLeafNode *node, extern __global__ void sd_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *my_card, int *opp_card, int n_card); extern __global__ void best_cfv_kernel(Node *node, int size, int n_hand); extern __global__ void cfv_kernel(Node *node, int size, int n_hand); -extern __global__ void updata_data_kernel(Node *node, int size, int n_hand, float pos_coef, float neg_coef, float coef); +extern __global__ void discount_data_kernel(Node *node, int size, int n_hand, float pos_coef, float neg_coef, float coef); #endif // _CUDA_FUNC_H_ \ No newline at end of file diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h index 8c6e503..3837729 100644 --- a/include/solver/slice_cfr.h +++ b/include/solver/slice_cfr.h @@ -34,7 +34,7 @@ using std::mutex; #define N_LEAF_TYPE 2 #define N_TYPE 5 -#define N_TASK_SIZE 5 + #define two_card_hash(card1, card2) ((1LL<<(card1)) | (1LL<<(card2))) #define tril_idx(r, c) (((r)*((r)-1)>>1)+(c)) // r>c>=0 @@ -49,12 +49,18 @@ using std::mutex; #define code_idx1(i) (((i)+1)<<16) #define decode_idx1(x) ((((x)>>16)&0xff) - 1) +#define EXP_TASK 0 +#define CFV_TASK 1 +#define CFR_TASK 2 + struct Node { int n_act = 0;// 动作数 int parent_offset = -1;// 本节点对应的父节点数据reach_prob的偏移量 float *parent_cfv = nullptr; // mutex *mtx = nullptr; float *data = nullptr;// cfv,regret_sum,strategy_sum,reach_prob,sum + float *opp_prob = nullptr; + size_t board = 0LL; }; struct LeafNode { float *reach_prob[N_PLAYER] = {nullptr,nullptr}; @@ -116,6 +122,8 @@ class SliceCFR : public Solver { int steps = 0, interval = 0, n_card = N_CARD, min_card = 0; int init_round = 0; int dfs_idx = 0;// 先序遍历 + unordered_map> node_idx; + int combination_num[N_ROUND-1] {1,N_CARD,N_CARD*N_CARD}; size_t init_board = 0; int hand_size[N_PLAYER]; float norm = 1;// 根节点概率归一化系数 @@ -162,9 +170,9 @@ class SliceCFR : public Solver { size_t init_strength_table(); void dfs(shared_ptr node, int parent_act=-1, int parent_dfs_idx=-1, int parent_p0_act=-1, int parent_p0_idx=-1, int parent_p1_act=-1, int parent_p1_idx=-1, int cnt0=0, int cnt1=0, int info=0); void init_poss_card(Deck& deck, size_t board); - virtual void step(int iter, int player, bool best_cfv=false); + virtual void step(int iter, int player, int task); virtual void leaf_cfv(int player); - void fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, size_t board); + void fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, float val, size_t board); void sd_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, int idx); void append_node_idx(int p_idx, int act_idx, int player, int cpu_node_idx); vector> pre_leaf_node_map;// [dfs_idx,act_idx] @@ -179,13 +187,19 @@ class SliceCFR : public Solver { // int mtx_idx = N_PLAYER; vector> strength; size_t _estimate_tree_size(shared_ptr node); - virtual void _reach_prob(int player, bool best_cfv=false); - virtual void _rm(int player, bool best_cfv=false); + virtual void _reach_prob(int player, bool avg_strategy); + virtual void _rm(int player, bool avg_strategy); virtual void clear_data(int player); virtual void clear_root_cfv(); virtual void post_process() {} json reConvertJson(const shared_ptr& node, int depth, int max_depth, int &idx, int info); - virtual vector> get_avg_strategy(int idx); + virtual vector> get_avg_strategy(int idx);// [n_hand,n_act] + virtual vector> get_ev(int idx);// [n_hand,n_act] + bool print_exploitability(int iter, Timer &timer); + void cfv_to_ev(); + void cfv_to_ev(Node *node, int player); + void get_prob_sum(vector &prob_sum, float &sum, int player, float *reach_prob, size_t board); + void output_data(ActionNode *node, vector &cards, vector>> &out, bool ev); }; #endif // _SLICE_CFR_H_ diff --git a/include/tools/CommandLineTool.h b/include/tools/CommandLineTool.h index e302120..f68cde4 100644 --- a/include/tools/CommandLineTool.h +++ b/include/tools/CommandLineTool.h @@ -78,7 +78,7 @@ class CommandLineTool{ string board; string res_file; string algorithm = "discounted_cfr"; - float accuracy; + float accuracy = 0.1; int max_iteration=100; bool use_isomorphism=0; int use_halffloats=0; diff --git a/src/Card.cpp b/src/Card.cpp index 8e9f04d..f06c089 100644 --- a/src/Card.cpp +++ b/src/Card.cpp @@ -26,6 +26,7 @@ const string& Card::getCard() { return this->card; } +// rank * 4 + suit,[13,4] int Card::getCardInt() { return this->card_int; } @@ -39,6 +40,7 @@ int Card::card2int(Card card) { return strCard2int(card.getCard()); } +// rank * 4 + suit,[13,4] int Card::strCard2int(const string &card) { char rank = card.at(0); char suit = card.at(1); diff --git a/src/solver/cuda_cfr.cu b/src/solver/cuda_cfr.cu index 52bfc79..4c12942 100644 --- a/src/solver/cuda_cfr.cu +++ b/src/solver/cuda_cfr.cu @@ -250,23 +250,23 @@ size_t CudaCFR::estimate_tree_size() { return size; } -void CudaCFR::_reach_prob(int player, bool best_cfv) { +void CudaCFR::_reach_prob(int player, bool avg_strategy) { vector& offset = slice_offset[player]; int n = offset.size() - 1, size = 0, block = 0, n_hand = hand_size[player]; for(int i = 0; i < n; i++) { size = offset[i+1] - offset[i]; block = block_size(size); - if(best_cfv) reach_prob_avg_kernel<<>>(dev_nodes+offset[i], size, n_hand); + if(avg_strategy) reach_prob_avg_kernel<<>>(dev_nodes+offset[i], size, n_hand); else reach_prob_kernel<<>>(dev_nodes+offset[i], size, n_hand); cudaDeviceSynchronize(); } } -void CudaCFR::_rm(int player, bool best_cfv) { +void CudaCFR::_rm(int player, bool avg_strategy) { int size = node_cnt[N_LEAF_TYPE + player]; int block = block_size(size); Node *node = dev_nodes + slice_offset[player][0]; - if(best_cfv) rm_avg_kernel<<>>(node, size, hand_size[player]); + if(avg_strategy) rm_avg_kernel<<>>(node, size, hand_size[player]); else rm_kernel<<>>(node, size, hand_size[player]); cudaDeviceSynchronize(); } @@ -288,19 +288,19 @@ void CudaCFR::clear_root_cfv() { cudaDeviceSynchronize(); } -void CudaCFR::step(int iter, int player, bool best_cfv) { +void CudaCFR::step(int iter, int player, int task) { Timer timer; int opp = 1 - player, my_hand = hand_size[player], size = 0, block = 0; - _reach_prob(opp, best_cfv); + _reach_prob(opp, task != CFR_TASK); size_t t1 = timer.ms(true); leaf_cfv(player); size_t t2 = timer.ms(true); - if(!best_cfv) { + if(task == CFR_TASK) { size = n_player_node; block = block_size(size); - updata_data_kernel<<>>(dev_nodes, size, my_hand, pos_coef, neg_coef, coef); + discount_data_kernel<<>>(dev_nodes, size, my_hand, pos_coef, neg_coef, coef); cudaDeviceSynchronize(); } size_t t3 = timer.ms(true); @@ -308,7 +308,7 @@ void CudaCFR::step(int iter, int player, bool best_cfv) { for(int i = offset.size()-2; i >= 0; i--) { size = offset[i+1] - offset[i]; block = block_size(size); - if(best_cfv) best_cfv_kernel<<>>(dev_nodes+offset[i], size, my_hand); + if(task == EXP_TASK) best_cfv_kernel<<>>(dev_nodes+offset[i], size, my_hand); else cfv_kernel<<>>(dev_nodes+offset[i], size, my_hand); cudaDeviceSynchronize(); } diff --git a/src/solver/cuda_func.cu b/src/solver/cuda_func.cu index 0c02b6e..5e7a52c 100644 --- a/src/solver/cuda_func.cu +++ b/src/solver/cuda_func.cu @@ -264,7 +264,7 @@ __global__ void cfv_kernel(Node *node, int size, int n_hand) { for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv } -__global__ void updata_data_kernel(Node *node, int size, int n_hand, float pos_coef, float neg_coef, float coef) { +__global__ void discount_data_kernel(Node *node, int size, int n_hand, float pos_coef, float neg_coef, float coef) { int i = blockIdx.x * blockDim.x + threadIdx.x; if(i >= size) return; node += i; diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp index 7c43e91..99e7a1d 100644 --- a/src/solver/slice_cfr.cpp +++ b/src/solver/slice_cfr.cpp @@ -129,8 +129,39 @@ void cfv_up(Node *node, int n_hand) { } for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv } +// 只计算cfv +void cfv_up_avg(Node *node, int n_hand) { + int n_act = node->n_act, size = n_act * n_hand; + int i = 0, h = 0, sum_offset = size << 2; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + float *strategy_sum = cfv + (size << 1); + // mutex *mtx = node->mtx; + for(h = 0; h < n_hand; h++) { + val = 0; + if(cfv[sum_offset+h] == 0) { + for(i = h; i < size; i += n_hand) val += cfv[i]; + val /= n_act;// uniform strategy + } + else { + for(i = h; i < size; i += n_hand) { + val += cfv[i] * strategy_sum[i]; + } + val /= cfv[sum_offset+h]; + } + // cfv[sum_offset+h] = val; + // mtx->lock(); + // parent_cfv[h] += val;// 需要加锁 + // mtx->unlock(); + atomic_ref(parent_cfv[h]).fetch_add(val, memory_order_relaxed); + // for(i = h; i < size; i += n_hand) regret_sum[i] += cfv[i] - val;// 更新regret_sum + // val = 0; + // for(i = h; i < size; i += n_hand) val += max(0.0f, regret_sum[i]); + // cfv[sum_offset+h] = val;// 求和 + } + // for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} // 在cfv_up前执行 -void updata_data(Node *node, int n_hand, float pos_coef, float neg_coef, float coef) { +void discount_data(Node *node, int n_hand, float pos_coef, float neg_coef, float coef) { int size = node->n_act * n_hand, i = 0; float *regret_sum = node->data + size, *strategy_sum = regret_sum + size; for(i = 0; i < size; i++) { @@ -139,6 +170,36 @@ void updata_data(Node *node, int n_hand, float pos_coef, float neg_coef, float c } } +void SliceCFR::cfv_to_ev() { + for(int i = 0; i < N_PLAYER; i++) { + vector& offset = slice_offset[i]; + #pragma omp parallel for + for(int j = offset[0]; j < offset.back(); j++) { + cfv_to_ev(player_node_ptr+j, i); + } + } +} +void SliceCFR::cfv_to_ev(Node *node, int player) { + float *opp_reach = node->opp_prob, *cfv = node->data; + size_t board = node->board; + vector opp_prob_sum(n_card, 0); + float prob_sum = 0, temp = 0; + get_prob_sum(opp_prob_sum, prob_sum, 1-player, opp_reach, board); + int n_hand = hand_size[player], size = node->n_act * n_hand, h = 0, i = 0; + int *same_hand = same_hand_ptr[player], *my_card = hand_card_ptr[player]; + size_t *my_hash = hand_hash_ptr[player]; + for(h = 0; h < n_hand; h++) { + temp = same_hand[h] != -1 ? opp_reach[same_hand[h]] : 0;// 重复计算的部分 + temp = prob_sum - opp_prob_sum[my_card[h]] - opp_prob_sum[my_card[h+n_hand]] + temp; + if((my_hash[h] & board) || temp == 0) { + for(i = h; i < size; i += n_hand) cfv[i] = 0; + } + else { + for(i = h; i < size; i += n_hand) cfv[i] /= temp; + } + } +} + // #define TIME_LOG #ifdef TIME_LOG atomic_ullong fold_time[16] = {0}, sd_time[16] = {0}; @@ -164,7 +225,7 @@ void SliceCFR::leaf_cfv(int player) { for(int j : vec[i].leaf_node_idx) { LeafNode &node = leaf_node[j]; if(j < sd_offset) { - fold_cfv(player, cfv, node.reach_prob[opp], my_hand, opp_hand, ev_ptr[j], node.info); + fold_cfv(player, cfv, node.reach_prob[opp], my_hand, ev_ptr[j], node.info); } else sd_cfv(player, cfv, node.reach_prob[opp], my_hand, opp_hand, ev_ptr[j], node.info); } @@ -176,23 +237,29 @@ void SliceCFR::leaf_cfv(int player) { // printf("leaf_cfv:%zd ms\n", timer.ms()); #endif } -void SliceCFR::fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, size_t board) { +void SliceCFR::get_prob_sum(vector &prob_sum, float &sum, int player, float *reach_prob, size_t board) { + float temp = 0; + int n_hand = hand_size[player], *hand_card = hand_card_ptr[player]; + size_t *hand_hash = hand_hash_ptr[player]; + for(int i = 0; i < n_hand; i++) { + if(hand_hash[i] & board) continue;// 对方手牌与公共牌冲突 + temp = reach_prob[i]; + prob_sum[hand_card[i]] += temp;// card1 + prob_sum[hand_card[i+n_hand]] += temp;// card2 + sum += temp; + } +} +void SliceCFR::fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, float val, size_t board) { #ifdef TIME_LOG Timer timer; #endif if(player != P0) val = -val; - size_t *my_hash = hand_hash_ptr[player], *opp_hash = hand_hash_ptr[1-player]; - int *my_card = hand_card_ptr[player], *opp_card = hand_card_ptr[1-player]; + size_t *my_hash = hand_hash_ptr[player]; + int *my_card = hand_card_ptr[player]; int *same_hand = same_hand_ptr[player], i = 0; vector opp_prob_sum(n_card, 0); float prob_sum = 0, temp = 0; - for(i = 0; i < opp_hand; i++) { - if(opp_hash[i] & board) continue;// 对方手牌与公共牌冲突 - temp = opp_reach[i]; - opp_prob_sum[opp_card[i]] += temp;// card1 - opp_prob_sum[opp_card[i+opp_hand]] += temp;// card2 - prob_sum += temp; - } + get_prob_sum(opp_prob_sum, prob_sum, 1-player, opp_reach, board); for(i = 0; i < my_hand; i++) { if(my_hash[i] & board) { // cfv[i] = 0;// 与公共牌冲突,cfv为0 @@ -281,7 +348,7 @@ size_t SliceCFR::init_leaf_node() { else { if(j == -1) info = 0; else if(k == -1) info = j; - else info = tril_idx(j, k); + else info = tril_idx(max(j, k), min(j, k)); } leaf_node[node_idx++].info = info; } @@ -414,19 +481,31 @@ size_t SliceCFR::init_player_node() { for(vector &nodes : slice[i]) {// 枚举slice slice_offset[i].push_back(mem_idx); for(int idx : nodes) {// 枚举node - DFSNode &node = dfs_node[idx]; - Node &target = player_node[mem_idx]; - target.n_act = node.n_act; - set_cfv_and_offset(node, -1, target.parent_cfv, target.parent_offset); - size = get_size(node.n_act, hand_size[node.player]) * sizeof(float); - target.data = (float *)malloc(size); - if(target.data == nullptr) throw runtime_error("malloc error"); - total += size; dfs_idx_map[idx] = mem_idx++; } } slice_offset[i].push_back(mem_idx); } + for(int idx = 0; idx < dfs_idx; idx++) { + if(dfs_idx_map[idx] == -1) continue; + DFSNode &node = dfs_node[idx]; + if(node.player != P0 && node.player != P1) throw runtime_error("unknow player"); + Node &target = player_node[dfs_idx_map[idx]]; + target.n_act = node.n_act; + set_cfv_and_offset(node, -1, target.parent_cfv, target.parent_offset); + float *ptr = nullptr; + int offset = 0; + set_cfv_and_offset(node, 1-node.player, ptr, offset); + target.opp_prob = ptr + offset; + target.board = init_board; + int j = decode_idx0(node.info), k = decode_idx1(node.info); + if(j != -1) target.board |= 1LL << poss_card[j]; + if(k != -1) target.board |= 1LL << poss_card[k]; + size = get_size(node.n_act, hand_size[node.player]) * sizeof(float); + target.data = (float *)malloc(size); + if(target.data == nullptr) throw runtime_error("malloc error"); + total += size; + } // mtx = vector(mtx_idx); // printf("%d,%d,%d\n", sizeof(mutex), mtx_idx, mtx_idx * sizeof(mutex)); // total += mtx_idx * sizeof(mutex); @@ -614,6 +693,22 @@ void SliceCFR::dfs(shared_ptr node, int parent_act, int parent_dfs int type = node->getType(), round = GameTreeNode::gameRound2int(node->getRound()), n_act = 0; if(type == GameTreeNode::ACTION) { shared_ptr act_node = dynamic_pointer_cast(node); + ActionNode *p = act_node.get(); + int r_offset = round - init_round; + if(node_idx.find(p) == node_idx.end()) node_idx[p] = vector(combination_num[r_offset], -1); + int j = decode_idx0(info), k = decode_idx1(info); + if(r_offset == 0) { + assert(j == -1 && k == -1); + node_idx[p][0] = curr_idx; + } + else if(r_offset == 1) { + assert(j != -1 && k == -1); + node_idx[p][poss_card[j]] = curr_idx; + } + else { + assert(r_offset == 2 && j != -1 && k != -1); + node_idx[p][poss_card[j]*N_CARD+poss_card[k]] = curr_idx; + } int player = act_node->getPlayer(); vector> children = act_node->getChildrens(); n_act = children.size(); @@ -638,13 +733,13 @@ void SliceCFR::dfs(shared_ptr node, int parent_act, int parent_dfs this->chance_node.push_back(curr_idx); if(child_type == GameTreeNode::ACTION || child_type == GameTreeNode::SHOWDOWN) {// 需要发1张牌 dfs_node.emplace_back(CHANCE_PLAYER, n_act, parent_act, info | round, parent_dfs_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); - // 发牌信息编码,只有1张牌时,占用idx0,有2张牌时,索引较大的占用idx0,较小的占用idx1 + // 发牌信息编码,只有1张牌时,占用idx0,有2张牌时,占用idx0,idx1 int j = decode_idx0(info), new_info = 0; for(int i = 0, k = 0; i < n_act; i++, k++) {// 动作索引i,poss_card索引k if(j == -1) new_info = code_idx0(k);// 第一次发牌 else {// 第二次发牌,最多发两次牌 if(k == j) k++;// 两次选的一样,则第二次改成下一个 - new_info = code_idx0(max(j,k)) | code_idx1(min(j,k));// idx0为较大值 + new_info = code_idx0(j) | code_idx1(k); } dfs(children, i, curr_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx, cnt0, cnt1, new_info); } @@ -661,7 +756,7 @@ void SliceCFR::dfs(shared_ptr node, int parent_act, int parent_dfs for(int k = j+1; k < n_act; k++) { ev[SHOWDOWN_TYPE].push_back(val); leaf_node_dfs[SHOWDOWN_TYPE].push_back(dfs_idx++); - info = code_idx0(k) | code_idx1(j);// idx0为较大值 + info = code_idx0(j) | code_idx1(k); dfs_node.emplace_back(CHANCE_PLAYER, 0, i++, info, curr_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); } } @@ -699,10 +794,10 @@ void SliceCFR::init_poss_card(Deck& deck, size_t board) { print_array(chance_den, N_ROUND); } -void SliceCFR::_reach_prob(int player, bool best_cfv) { +void SliceCFR::_reach_prob(int player, bool avg_strategy) { vector& offset = slice_offset[player]; int n = offset.size(), n_hand = hand_size[player]; - node_func func = best_cfv ? reach_prob_avg : reach_prob; + node_func func = avg_strategy ? reach_prob_avg : reach_prob; for(int i = 1; i < n; i++) { #pragma omp parallel for for(int j = offset[i-1]; j < offset[i]; j++) { @@ -710,8 +805,8 @@ void SliceCFR::_reach_prob(int player, bool best_cfv) { } } } -void SliceCFR::_rm(int player, bool best_cfv) { - node_func func = best_cfv ? rm_avg : rm; +void SliceCFR::_rm(int player, bool avg_strategy) { + node_func func = avg_strategy ? rm_avg : rm; int s = slice_offset[player][0], e = slice_offset[player].back(), n_hand = hand_size[player]; #pragma omp parallel for for(int i = s; i < e; i++) { @@ -733,59 +828,66 @@ void SliceCFR::clear_root_cfv() { memset(root_cfv_ptr[P0], 0, size); } +bool SliceCFR::print_exploitability(int iter, Timer &timer) { + vector res = exploitability(); + logger->log("%d:%.3fs", iter, timer.ms()/1000.0); + float avg = (res[0] + res[1]) / 2; + logger->log("%d:%f %f %f", iter, res[0], res[1], avg); + return avg <= tol; +} + void SliceCFR::train() { init(); if(!init_succ) return; - size_t start = timeSinceEpochMillisec(), total = 0; Timer timer; clear_data(P0); clear_data(P1); // _rm(P0, false); // _rm(P1, false); // _reach_prob(P0, false); - vector res = exploitability(); - logger->log("0:%f %f %f", res[0], res[1], (res[0]+res[1])/2); + print_exploitability(0, timer); // 计算exploitability后,双方的rm和p0的reach_prob已经恢复 - pos_coef = neg_coef = coef = 0; + // pos_coef = neg_coef = coef = 0; double temp = 0; - int cnt = 0, iter = 1; - for(iter = 1; iter <= steps; iter++) { - clear_root_cfv(); - for(int player = P0; player < N_PLAYER; player++) { - step(iter, player, false); - } + int cnt = 0, iter = 0; + while(iter < steps) { temp = pow(iter, alpha); pos_coef = temp / (temp + 1); temp = pow(iter, beta); neg_coef = temp / (temp + 1); // neg_coef = 0.5; coef = pow((float)iter/(iter+1), gamma); + + clear_root_cfv(); + for(int player = P0; player < N_PLAYER; player++) { + step(iter, player, CFR_TASK); + } + iter++; if((++cnt) == interval) { cnt = 0; - res = exploitability(); - total = timeSinceEpochMillisec() - start; - logger->log("%d:%.3f,%.3fs", iter, timer.ms()/1000.0, total/1000.0); - temp = (res[0] + res[1]) / 2; - logger->log("%d:%f %f %f", iter, res[0], res[1], temp); - if(temp <= tol) break; + if(print_exploitability(iter, timer)) break; } if(stop_flag) break; } if(cnt) { - res = exploitability(); - total = timeSinceEpochMillisec() - start; - logger->log("%d:%.3f,%.3fs", iter, timer.ms()/1000.0, total/1000.0); - logger->log("%d:%f %f %f", iter, res[0], res[1], (res[0]+res[1])/2); + print_exploitability(iter, timer); + } + logger->log("collecting statics"); + for(int player = P0; player < N_PLAYER; player++) { + _rm(1-player, true); + step(iter, player, CFV_TASK); } + cfv_to_ev(); + logger->log("statics collected"); } -// player到达概率已经计算好 -void SliceCFR::step(int iter, int player, bool best_cfv) { +// 执行更新任务时,player到达概率需要提前计算好 +void SliceCFR::step(int iter, int player, int task) { #ifdef TIME_LOG size_t start = timeSinceEpochMillisec(), end = 0; #endif int opp = 1 - player, my_hand = hand_size[player]; - _reach_prob(opp, best_cfv); + _reach_prob(opp, task != CFR_TASK); #ifdef TIME_LOG end = timeSinceEpochMillisec(); size_t t1 = end - start; @@ -800,10 +902,10 @@ void SliceCFR::step(int iter, int player, bool best_cfv) { #endif vector& offset = slice_offset[player]; - if(!best_cfv) { + if(task == CFR_TASK) { #pragma omp parallel for for(int j = offset[0]; j < offset.back(); j++) { - updata_data(player_node_ptr+j, my_hand, pos_coef, neg_coef, coef); + discount_data(player_node_ptr+j, my_hand, pos_coef, neg_coef, coef); } } #ifdef TIME_LOG @@ -812,7 +914,12 @@ void SliceCFR::step(int iter, int player, bool best_cfv) { start = end; #endif - node_func func = best_cfv ? best_cfv_up : cfv_up; + node_func func; + switch(task) { + case EXP_TASK:{func = best_cfv_up;break;} + case CFV_TASK:{func = cfv_up_avg;break;} + default:func = cfv_up; + } for(int i = offset.size()-1; i > 0; i--) { #pragma omp parallel for for(int j = offset[i-1]; j < offset[i]; j++) { @@ -838,7 +945,7 @@ vector SliceCFR::exploitability() { #ifdef TIME_LOG size_t t1 = timeSinceEpochMillisec() - start; #endif - step(0, player, true); + step(0, player, EXP_TASK); #ifdef TIME_LOG start = timeSinceEpochMillisec(); #endif @@ -868,10 +975,48 @@ json SliceCFR::dumps(bool with_status, int depth) {// depth:max_round return std::move(ans); } vector>> SliceCFR::get_strategy(shared_ptr node, vector cards) { - return {}; + vector>> ans(N_CARD, vector>(N_CARD)); + output_data(node.get(), cards, ans, false); + return std::move(ans); } vector>> SliceCFR::get_evs(shared_ptr node, vector cards) { - return {}; + vector>> ans(N_CARD, vector>(N_CARD)); + output_data(node.get(), cards, ans, true); + return std::move(ans); +} +void SliceCFR::output_data(ActionNode *node, vector &cards, vector>> &out, bool ev) { + int r_offset = GameTreeNode::gameRound2int(node->getRound()) - init_round; + if(cards.size() != r_offset || r_offset > 2) throw runtime_error("chance_cards error"); + int idx = 0; + size_t board = init_board; + if(r_offset >= 1) { + idx = cards[0].getCardInt(); + board |= 1LL << idx; + } + if(r_offset == 2) { + idx = idx * N_CARD + cards[1].getCardInt(); + board |= 1LL << cards[1].getCardInt(); + } + vector> data; + if(ev) data = get_ev(node_idx.at(node)[idx]); + else data = get_avg_strategy(node_idx.at(node)[idx]); + int player = node->getPlayer(), n_hand = hand_size[player], *card = hand_card_ptr[player]; + size_t *ptr = hand_hash_ptr[player]; + for(int h = 0; h < n_hand; h++) { + if(!cards_valid(ptr[h], board)) continue; + out[card[h]+min_card][card[h+n_hand]+min_card].swap(data[h]); + } +} +vector> SliceCFR::get_ev(int idx) { + Node &node = player_node[dfs_idx_map[idx]]; + int n_hand = hand_size[dfs_node[idx].player], n_act = node.n_act; + int i = 0, h = 0, j = 0; + float *cfv = node.data; + vector> ev(n_hand, vector(n_act));// [n_hand,n_act] + for(j = 0; j < n_act; j++) { + for(h = 0; h < n_hand; h++) ev[h][j] = cfv[i++]; + } + return std::move(ev); } vector> SliceCFR::get_avg_strategy(int idx) { Node &node = player_node[dfs_idx_map[idx]]; @@ -889,7 +1034,7 @@ vector> SliceCFR::get_avg_strategy(int idx) { for(j = 0, i = h; j < n_act; j++, i += n_hand) strategy[h][j] = strategy_sum[i] / sum; } } - return strategy; + return std::move(strategy); } json SliceCFR::reConvertJson(const shared_ptr& node, int depth, int max_depth, int &idx, int info) { int curr_idx = idx++; @@ -939,7 +1084,6 @@ json SliceCFR::reConvertJson(const shared_ptr& node, int depth, in if(j == -1) new_info = code_idx0(k);// 第一次发牌 else {// 第二次发牌,最多发两次牌 if(k == j) k++;// 两次选的一样,则第二次改成下一个 - // new_info = code_idx0(max(j,k)) | code_idx1(min(j,k));// idx0为较大值 } json child = reConvertJson(children, depth, max_depth, idx, new_info); if(depth < max_depth) ans["dealcards"][Card::intCard2Str(poss_card[k])] = child; From f160411c19fca59e9465e39912801437b04cec92 Mon Sep 17 00:00:00 2001 From: yffbit Date: Sat, 1 Jun 2024 20:13:23 +0800 Subject: [PATCH 17/19] update --- include/solver/cuda_cfr.h | 2 ++ include/solver/slice_cfr.h | 2 +- src/solver/cuda_cfr.cu | 6 +++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/include/solver/cuda_cfr.h b/include/solver/cuda_cfr.h index 42a01b7..6b3b010 100644 --- a/include/solver/cuda_cfr.h +++ b/include/solver/cuda_cfr.h @@ -78,6 +78,8 @@ class CudaCFR : public SliceCFR { virtual void clear_root_cfv(); virtual void post_process(); virtual vector> get_avg_strategy(int idx); + virtual vector> get_ev(int idx); + virtual void cfv_to_ev(); }; #endif // _CUDA_CFR_H_ diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h index 3837729..231f09f 100644 --- a/include/solver/slice_cfr.h +++ b/include/solver/slice_cfr.h @@ -196,7 +196,7 @@ class SliceCFR : public Solver { virtual vector> get_avg_strategy(int idx);// [n_hand,n_act] virtual vector> get_ev(int idx);// [n_hand,n_act] bool print_exploitability(int iter, Timer &timer); - void cfv_to_ev(); + virtual void cfv_to_ev(); void cfv_to_ev(Node *node, int player); void get_prob_sum(vector &prob_sum, float &sum, int player, float *reach_prob, size_t board); void output_data(ActionNode *node, vector &cards, vector>> &out, bool ev); diff --git a/src/solver/cuda_cfr.cu b/src/solver/cuda_cfr.cu index 4c12942..dcd3a3b 100644 --- a/src/solver/cuda_cfr.cu +++ b/src/solver/cuda_cfr.cu @@ -343,4 +343,8 @@ vector> CudaCFR::get_avg_strategy(int idx) { } } return strategy; -} \ No newline at end of file +} +vector> CudaCFR::get_ev(int idx) { + return {}; +} +void CudaCFR::cfv_to_ev() {} From 191c88f44eea2a02f959de27e1c4924932e1e608 Mon Sep 17 00:00:00 2001 From: yffbit Date: Sun, 14 Jul 2024 17:51:33 +0800 Subject: [PATCH 18/19] update main.yml --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fac4d3e..92ab062 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,7 +18,7 @@ jobs: # This workflow contains a single job called "build" build: # The type of runner that the job will run on - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 # Steps represent a sequence of tasks that will be executed as part of the job steps: From dab6d4d08ab9db29ed35909ed5ba745647def8df Mon Sep 17 00:00:00 2001 From: yffbit Date: Sun, 14 Jul 2024 17:58:03 +0800 Subject: [PATCH 19/19] update main.yml --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 92ab062..82549cd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,7 +32,7 @@ jobs: # Runs a set of commands using the runners shell - name: install dependencies run: | - sudo apt install -y qt5-default qtbase5-dev qt5-qmake qttools5-dev build-essential wget cmake + sudo apt install -y qtbase5-dev qt5-qmake qttools5-dev build-essential wget cmake # Runs a set of commands using the runners shell - name: make appimage