This repository has been archived by the owner on May 2, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
play_gomoku_az.py
63 lines (50 loc) · 2.19 KB
/
play_gomoku_az.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import logging
import os
import sys
import torch
import yaml
from alphazero.agents.alphazero import AlphaZeroArgMaxAgent
from alphazero.alphazero.mcts import MonteCarloTreeSearch
from alphazero.alphazero.nn_modules.nets import dual_resnet
from alphazero.alphazero.state_encoders.gomoku_state_encoder import GomokuStateEncoder
from alphazero.games.gomoku import GomokuGame, GomokuPlayer, GomokuMove
# Root-logger configuration: timestamped INFO-and-above records go to stderr.
FORMAT = '%(asctime)s - %(name)-15s - %(levelname)s - %(message)s'
logging.basicConfig(stream=sys.stderr, level=logging.INFO,
                    format=FORMAT, datefmt='%m/%d/%Y %I:%M:%S %p')

# Settings shared by the game, state encoder, network and agent.
# The path is relative to the current working directory, so the script must be
# launched from the directory containing gomoku.yaml. An explicit encoding keeps
# decoding independent of the platform's locale default.
with open('gomoku.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# safe_load returns None for an empty file (or a scalar/list for other
# documents); fail with a clear message instead of a TypeError below.
if not isinstance(config, dict):
    raise ValueError("gomoku.yaml must contain a mapping of configuration settings")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
def read_move(player: GomokuPlayer) -> GomokuMove:
    """Read a move for *player* from stdin, re-prompting on malformed input.

    The expected input is two whitespace-separated integers (e.g. ``7 7``).
    They are passed to ``GomokuMove`` in the order typed. Input with the wrong
    number of tokens, or with tokens that are not integers, prints a hint and
    prompts again instead of raising. Legality of the move is NOT checked here.

    Raises:
        EOFError: if stdin is closed before a well-formed move is entered.
    """
    while True:
        tokens = input(f"{player.name} move: ").split()
        try:
            # Unpacking raises ValueError on a wrong token count, and int()
            # raises ValueError on non-numeric text.
            x, y = (int(token) for token in tokens)
        except ValueError:
            print("Invalid input, enter two integers separated by a space")
            continue
        return GomokuMove(x, y)
def main() -> None:
    """Play one game of Gomoku: a human (BLACK, via stdin) vs. the AlphaZero agent.

    Usage: ``python play_gomoku_az.py [CHECKPOINT]``.
    CHECKPOINT defaults to ``pretrained/gomoku_playable.pth``. Pass a training
    checkpoint such as ``<log_dir>/best.pth`` (log_dir from gomoku.yaml) to play
    against it instead.
    """
    game = GomokuGame(config['game_size'])
    state_encoder = GomokuStateEncoder(config['device'], num_history=config['num_history'])
    net = dual_resnet(game, config)

    if len(sys.argv) > 1:
        checkpoint = sys.argv[1]
    else:
        checkpoint = os.path.join('pretrained', 'gomoku_playable.pth')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    # (config['device'] is 'cpu' whenever CUDA is unavailable).
    # torch.load unpickles the file: only load checkpoints you trust.
    net.load_state_dict(torch.load(checkpoint, map_location=config['device']))
    # Switch layers such as dropout / batch-norm to inference behaviour.
    net.eval()
    agent = AlphaZeroArgMaxAgent(game, state_encoder, net, config)

    while not game.is_over:
        game.show_board()
        if game.current_player == GomokuPlayer.BLACK:
            # Human turn: keep asking until the current state accepts the move.
            move = read_move(game.current_player)
            while not game.state.is_legal_move(move):
                print("Illegal move, try again")
                move = read_move(game.current_player)
        else:
            move = agent.select_move(game.state)
            print(f"Agent {game.current_player!s} move: {move}")
        game.play(move)

    print("--- GAME OVER ---")
    game.show_board()
    # NOTE(review): relies on game.winner being falsy on a draw (presumably None).
    # If a GomokuPlayer member can itself be falsy (e.g. an IntEnum with value 0),
    # a win by that player would be reported as a tie — confirm against GomokuGame.
    if game.winner:
        print(f"{game.winner.name} wins!")
    else:
        print("tie :(")


if __name__ == '__main__':
    main()