Skip to content

Commit

Permalink
update for CxxWrap 0.14 compat
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremiah Lewis committed Mar 1, 2024
1 parent 1f2364e commit 511dd9d
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 13 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
name = "OpenSpiel"
uuid = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
authors = ["Jun Tian <[email protected]>"]
version = "0.1.5"
version = "0.2.0"

[deps]
CxxWrap = "1f15a43c-97ca-5a2a-ae31-89f07a497df4"
OpenSpiel_jll = "bd10a763-4654-5023-a028-c4918c6cd33e"

[compat]
CxxWrap = "0.12, 0.13, 0.14"
CxxWrap = "0.14"
OpenSpiel_jll = "1"
julia = "1.6"

Expand Down
2 changes: 1 addition & 1 deletion src/OpenSpiel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using OpenSpiel_jll

import CxxWrap:argument_overloads

@wrapmodule(libspieljl)
@wrapmodule(OpenSpiel_jll.get_libspieljl_path)

include("patch.jl")

Expand Down
84 changes: 79 additions & 5 deletions src/patch.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
Base.show(io::IO, g::CxxWrap.StdLib.SharedPtrAllocated{Game}) = print(io, to_string(g))
Base.show(io::IO, s::CxxWrap.StdLib.UniquePtrAllocated{State}) = print(io, to_string(s))
Base.show(io::IO, g::CxxWrap.StdLib.SharedPtrAllocated{Game}) = print(io, to_string(g[]))
Base.show(io::IO, s::CxxWrap.StdLib.UniquePtrAllocated{State}) = print(io, to_string(s[]))
Base.show(io::IO, gp::Union{GameParameterAllocated, GameParameterDereferenced}) = print(io, to_repr_string(gp))

function Base.hash(s::CxxWrap.CxxWrapCore.SmartPointer{T}, h::UInt) where {T<:Union{Game,State}}
hash(to_string(s), h)
hash(to_string(s[]), h)
end

function Base.:(==)(s::CxxWrap.CxxWrapCore.SmartPointer{T}, ss::CxxWrap.CxxWrapCore.SmartPointer{T}) where {T<:Union{Game, State}}
to_string(s) == to_string(ss)
to_string(s[]) == to_string(ss[])
end

GameParameter(x::Int) = GameParameter(Ref(Int32(x)))

Base.copy(s::CxxWrap.StdLib.UniquePtrAllocated{State}) = deepcopy(s)
Base.deepcopy(s::CxxWrap.StdLib.UniquePtrAllocated{State}) = clone(s)
Base.deepcopy(s::CxxWrap.StdLib.UniquePtrAllocated{State}) = clone(s[])
Base.reshape(s::CxxWrap.StdLib.StdVectorAllocated, dims::Int32...) = reshape(s, Int.(dims))

if Sys.KERNEL == :Linux
Expand Down Expand Up @@ -75,3 +75,77 @@ function load_game_as_turn_based(s::Union{String, CxxWrap.StdLib.StdStringAlloca
_load_game_as_turn_based(s, StdMap{StdString, GameParameter}(ps))
end
end

is_chance_node(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_chance_node(state[])

new_initial_state(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = new_initial_state(game[])

legal_actions(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = legal_actions(state[])

child(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i::Int64) = child(state[], i)

is_terminal(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_terminal(state[])

information_state_string(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = information_state_string(state[])

get_uniform_policy(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = get_uniform_policy(game[])

record_batched_trajectories(game::CxxWrap.StdLib.SharedPtrAllocated{Game}, p::CxxWrap.StdLib.StdVectorAllocated{TabularPolicy}, m::StdMapAllocated{StdString, Int32}, i::Int64, b::Bool, i2::Int64, i3::Int64) = record_batched_trajectories(game[], p, m, i, b, i2, i3)

expected_returns(state::CxxWrap.StdLib.UniquePtrAllocated{State}, policy::CxxWrap.StdLib.SharedPtrAllocated{Policy}, i::Int64) = expected_returns(state[], policy[], i)

exploitability(game::CxxWrap.StdLib.SharedPtrAllocated{Game}, policy::CxxWrap.StdLib.SharedPtrAllocated{Policy}) = exploitability(game[], policy[])

current_player(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = current_player(state[])

action_to_string(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i1, i2) = action_to_string(state[], i1, i2)

apply_action(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i::AbstractVector{<:Number}) = apply_action(state[], i)

apply_action(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i::Number) = apply_action(state[], i)

restart_at(b::MCTSBotAllocated, s::CxxWrap.StdLib.UniquePtrAllocated{State}) = restart_at(b, s[])

best_child(root::CxxWrap.StdLib.UniquePtrAllocated{SearchNode}) = best_child(root[])

get_outcome(root::CxxWrap.StdLib.UniquePtrAllocated{SearchNode}) = get_outcome(root[])

get_player(p::CxxWrap.StdLib.UniquePtrAllocated{SearchNode}) = get_player(p[])

get_children(root::CxxWrap.StdLib.UniquePtrAllocated{SearchNode}) = get_children(root[])

is_simultaneous_node(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_simultaneous_node(state[])

step(bot, state::CxxWrap.StdLib.UniquePtrAllocated{State}) = step(bot, state[])

chance_outcomes(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = chance_outcomes(state[])

returns(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = returns(state[])

min_utility(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = min_utility(game[])

max_utility(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = max_utility(game[])

serialize_game_and_state(game::CxxWrap.StdLib.SharedPtrAllocated{Game}, state::CxxWrap.StdLib.UniquePtrAllocated{State}) = serialize_game_and_state(game[], state[])

is_mean_field_node(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_mean_field_node(state[])

legal_actions(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i) = legal_actions(state[], i)

history(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = history(state[])

is_player_node(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_player_node(state[])

num_players(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = num_players(game[])

distribution_support(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = distribution_support(state[])

get_type(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = get_type(game[])

update_distribution(state::CxxWrap.StdLib.UniquePtrAllocated{State}, dist::CxxWrap.StdLib.StdVectorAllocated{Float64}) = update_distribution(state[], dist)

num_cols(game::CxxWrap.StdLib.SharedPtrAllocated{MatrixGame}) = num_cols(game[])

num_rows(game::CxxWrap.StdLib.SharedPtrAllocated{MatrixGame}) = num_rows(game[])

extensive_to_matrix_game(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = extensive_to_matrix_game(game[])
6 changes: 3 additions & 3 deletions test/bots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
@testset "MCTSBot" begin
UCT_C = 2.

init_bot(game, max_simulations, evaluator) = MCTSBot(game, evaluator, UCT_C, max_simulations, 5, true, 42, false, UCT, 0., 0.)
init_bot(game, max_simulations, evaluator) = MCTSBot(game[], evaluator, UCT_C, max_simulations, 5, true, 42, false, UCT, 0., 0.)

@testset "can play tic_tac_toe" begin
game = load_game("tic_tac_toe")
Expand Down Expand Up @@ -51,8 +51,8 @@
apply_action(state, get_action_by_str(state, action_str))
end
evaluator = random_rollout_evaluator_factory(20, 42)
bot = MCTSBot(game, evaluator, UCT_C, 10000, 10, true, 42, false, UCT, 0., 0.)
mcts_search(bot, state), state
bot = MCTSBot(game[], evaluator, UCT_C, 10000, 10, true, 42, false, UCT, 0., 0.)
mcts_search(bot, state[]), state[]
end

@testset "solve draw" begin
Expand Down
4 changes: 2 additions & 2 deletions test/cfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ test_exploitability_kuhn_poker(game, policy) = @test exploitability(game, policy

@testset "CFRSolver" begin
game = load_game("kuhn_poker")
solver = CFRSolver(game)
solver = CFRSolver(game[])
for _ in 1:300
evaluate_and_update_policy(solver)
end
Expand All @@ -26,7 +26,7 @@ end

@testset "CFRPlusSolver" begin
game = load_game("kuhn_poker")
solver = CFRPlusSolver(game)
solver = CFRPlusSolver(game[])
for _ in 1:200
evaluate_and_update_policy(solver)
end
Expand Down

0 comments on commit 511dd9d

Please sign in to comment.