diff --git a/Project.toml b/Project.toml index 4abd2fa..485ab2e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ConstraintDomains" uuid = "5800fd60-8556-4464-8d61-84ebf7a0bedb" authors = ["Jean-François Baffier"] -version = "0.3.15" +version = "0.4.0" [deps] ConstraintCommons = "e37357d9-0691-492f-a822-e5ea6a920954" @@ -11,7 +11,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" [compat] -ConstraintCommons = "0.2" +ConstraintCommons = "0.3" Intervals = "1" PatternFolds = "0.2" StatsBase = "0.34" diff --git a/src/common.jl b/src/common.jl index 93e68a9..15dfecd 100644 --- a/src/common.jl +++ b/src/common.jl @@ -64,7 +64,7 @@ Convert various arguments into valid domains format. to_domains(domain_sizes::Vector{Int}) = map(ds -> domain(0:ds), domain_sizes) function to_domains(X, ds::Int = δ_extrema(X) + 1) - d = domain(0:ds-1) + d = domain(0:(ds-1)) return fill(d, length(first(X))) end diff --git a/src/explore.jl b/src/explore.jl index 6c93895..8a0db35 100644 --- a/src/explore.jl +++ b/src/explore.jl @@ -1,4 +1,4 @@ -struct ExploreSettings +mutable struct ExploreSettings complete_search_limit::Int max_samplings::Int search::Symbol @@ -42,22 +42,52 @@ function ExploreSettings( return ExploreSettings(complete_search_limit, max_samplings, search, solutions_limit) end -struct ExplorerState{T} +abstract type AbstractExplorerState end + +struct CompleteExplorerState{N,T} <: AbstractExplorerState + best::Vector{T} + solutions::Vector{NTuple{N,T}} + non_solutions::Vector{NTuple{N,T}} + + CompleteExplorerState{N,T}() where {N,T} = + new{N,T}(Vector{T}(), Vector{NTuple{N,T}}(), Vector{NTuple{N,T}}()) +end + +function explorer_state(domains, ::Val{:complete}) + return CompleteExplorerState{length(domains),Union{map(eltype, domains)...}}() +end + +struct PartialExplorerState{T} <: AbstractExplorerState best::Vector{T} solutions::Set{Vector{T}} non_solutions::Set{Vector{T}} - ExplorerState{T}() where {T} = new{T}([], Set{Vector{T}}(), Set{Vector{T}}()) + PartialExplorerState{T}() where {T} = + new{T}(Vector{T}(), Set{Vector{T}}(), Set{Vector{T}}()) +end +function explorer_state(domains, ::Val{:partial}) + return PartialExplorerState{Union{map(eltype, domains)...}}() end -ExplorerState(domains) = ExplorerState{Union{map(eltype, domains)...}}() - -mutable struct Explorer{F1<:Function,D<:AbstractDomain,F2<:Union{Function,Nothing},T} +mutable struct Explorer{ + F1<:Function, + D<:AbstractDomain, + F2<:Union{Function,Nothing}, + S<:AbstractExplorerState, +} concepts::Dict{Int,Tuple{F1,Vector{Int}}} domains::Dict{Int,D} objective::F2 settings::ExploreSettings - state::ExplorerState{T} + state::S + + function Explorer(concepts, domains, objective, settings, state) + F1 = isempty(concepts) ? Function : typeof(concepts).parameters[2].parameters[1] + D = isempty(domains) ? AbstractDomain : typeof(domains).parameters[2] + F2 = typeof(objective) + S = typeof(state) + return new{F1,D,F2,S}(concepts, domains, objective, settings, state) + end end """ @@ -88,13 +118,14 @@ function Explorer( objective = nothing; settings = ExploreSettings(domains), ) - F1 = isempty(concepts) ? Function : Union{map(c -> typeof(c[1]), concepts)...} - D = isempty(domains) ? AbstractDomain : Union{map(typeof, domains)...} - F2 = typeof(objective) - T = isempty(domains) ? Real : Union{map(eltype, domains)...} + if settings.search == :flexible + settings.search = + settings.max_samplings < settings.complete_search_limit ? :complete : :partial + end + state = explorer_state(domains, Val(settings.search)) d_c = Dict(enumerate(concepts)) d_d = Dict(enumerate(domains)) - return Explorer{F1,D,F2,T}(d_c, d_d, objective, settings, ExplorerState{T}()) + return Explorer(d_c, d_d, objective, settings, state) end function Explorer() @@ -225,15 +256,14 @@ function update_exploration!(explorer, f, c, search = explorer.settings.search) obj = explorer.objective sl = search == :complete ? Inf : explorer.settings.solutions_limit - cv = collect(c) - if f(cv) + if f(c) if length(solutions) < sl - push!(solutions, cv) + push!(solutions, c) obj !== nothing && (explorer.state.best = argmin(obj, solutions)) end else if length(non_sltns) < sl - push!(non_sltns, cv) + push!(non_sltns, c) end end return nothing @@ -261,7 +291,9 @@ function _explore!(explorer, f, ::Val{:partial};) end function _explore!(explorer, f, ::Val{:complete}) - C = Base.Iterators.product(map(d -> get_domain(d), explorer.domains |> values)...) + C = Base.Iterators.product( + Iterators.map(d -> get_domain(d), explorer.domains |> values)..., + ) foreach(c -> update_exploration!(explorer, f, c, :complete), C) return nothing end @@ -293,12 +325,7 @@ function explore!(explorer::Explorer) f(isempty(vars) ? x : @view x[vars]) for (f, vars) in explorer.concepts |> values ]) - s = explorer.settings - search = s.search - if search == :flexible - search = s.max_samplings < s.complete_search_limit ? :complete : :partial - end - return _explore!(explorer, c, Val(search)) + return _explore!(explorer, c, Val(explorer.settings.search)) end @@ -346,5 +373,5 @@ end @test length(X) == factorial(4) @test length(X̅) == 4^4 - factorial(4) - explorer = ConstraintDomains.Explorer() + explorer = ConstraintDomains.Explorer([(allunique, 1:4)], domains) end