diff --git a/src/fit.jl b/src/fit.jl index ee081de..20d878e 100644 --- a/src/fit.jl +++ b/src/fit.jl @@ -77,7 +77,8 @@ end "element-wise trainer" trainepoch_E!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(), :queue => QTrees.thread_queue(), - :pairlist => Vector{QTrees.CoItem}(), + :itemlist => Vector{QTrees.CoItem}(), + :pairlist => Vector{Tuple{Int, Int}}(), :updated => Set{Int}(), :spqtree => QTrees.hash_spacial_qtree(inputs)) trainepoch_E!(s::Symbol) = get(Dict(:patient => 20, :nepoch => 2000), s, nothing) @@ -101,7 +102,8 @@ end trainepoch_EM!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(), :queue => QTrees.thread_queue(), :memory => intlru(length(inputs)), - :pairlist => Vector{QTrees.CoItem}(), + :itemlist => Vector{QTrees.CoItem}(), + :pairlist => Vector{Tuple{Int, Int}}(), :updated => Set{Int}(), :spqtree => QTrees.hash_spacial_qtree(inputs)) trainepoch_EM!(s::Symbol) = get(Dict(:patient => 10, :nepoch => 1000), s, nothing) @@ -212,11 +214,12 @@ end "dynamic trainer" trainepoch_D!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(), :queue => QTrees.thread_queue(), - :pairlist => Vector{QTrees.CoItem}(), + :itemlist => Vector{QTrees.CoItem}(), + :pairlist => Vector{Tuple{Int, Int}}(), :updated => QTrees.UpdatedSet(1:length(inputs)), :loops => 10, :spqtree => QTrees.linked_spacial_qtree(inputs), #fllowing 4 tiems: pre-allocating for dynamiccollisions - :sptree2 => QTrees.hash_spacial_qtree(inputs), + :sptqree2 => QTrees.hash_spacial_qtree(inputs), :lbcollector => Vector{Int}(), :treenodestack => Vector{QTrees.SpacialQTreeNode}()) trainepoch_D!(s::Symbol) = get(Dict(:patient => 10, :nepoch => 2000), s, nothing) diff --git a/src/qtree_functions.jl b/src/qtree_functions.jl index 3495e55..45f8f87 100644 --- a/src/qtree_functions.jl +++ b/src/qtree_functions.jl @@ -96,9 +96,12 @@ function _totalcollisions_native(qtrees::AbstractVector, coitems::Vector{CoItem} colist end function _totalcollisions_native(qtrees::AbstractVector, - labels::AbstractVector{<:Integer}=1:length(qtrees); kargs...) + labels::AbstractVector{<:Integer}=1:length(qtrees); + pairlist::AbstractVector{Tuple{Int, Int}}=Vector{Tuple{Int, Int}}(), kargs...) l = length(labels) - _totalcollisions_native(qtrees, [@inbounds (labels[i], labels[j]) for i in 1:l for j in l:-1:i + 1]; kargs...) + empty!(pairlist) + append!(pairlist, (@inbounds (labels[i], labels[j]) for i in 1:l for j in l:-1:i + 1)) + _totalcollisions_native(qtrees, pairlist; kargs...) end function _totalcollisions_native(qtrees::AbstractVector, labels::AbstractSet{<:Integer}; kargs...) _totalcollisions_native(qtrees, labels |> collect; kargs...) @@ -155,15 +158,15 @@ function locate!(qts::AbstractVector, labels::Union{AbstractVector{Int},Abstract spqtree end -function collisions_boundsfilter(qtrees, spindex, lowlabels, higlabels, pairlist, colist) +function collisions_boundsfilter(qtrees, spindex, lowlabels, higlabels, itemlist, colist) for hlb in higlabels # check here because there are no bounds checking in _collision_randbfs - collisions_boundsfilter(qtrees, spindex, lowlabels, hlb, pairlist, colist) + collisions_boundsfilter(qtrees, spindex, lowlabels, hlb, itemlist, colist) end end -function collisions_boundsfilter(qtrees, spindex, lowlabels, hlb::Int, pairlist, colist) +function collisions_boundsfilter(qtrees, spindex, lowlabels, hlb::Int, itemlist, colist) if inkernelbounds(@inbounds(qtrees[hlb][spindex[1]]), spindex[2], spindex[3]) - append!(pairlist, ((llb, hlb)=>spindex for llb in lowlabels)) + append!(itemlist, ((llb, hlb)=>spindex for llb in lowlabels)) elseif getdefault(@inbounds(qtrees[hlb][1])) == QTrees.FULL for llb in lowlabels if @inbounds(qtrees[llb][spindex]) != QTrees.EMPTY @@ -173,22 +176,22 @@ function collisions_boundsfilter(qtrees, spindex, lowlabels, hlb::Int, pairlist, end end end -function collisions_boundsfilter(qtrees, spindex, llb::Int, higlabels, pairlist, colist) - collisions_boundsfilter(qtrees, spindex, (llb,), higlabels, pairlist, colist) +function collisions_boundsfilter(qtrees, spindex, llb::Int, higlabels, itemlist, colist) + collisions_boundsfilter(qtrees, spindex, (llb,), higlabels, itemlist, colist) end @assert collect(Iterators.product(1:2, 4:6))[1] == (1, 4) function totalcollisions_spacial(qtrees::AbstractVector, spqtree::HashSpacialQTree; - colist=Vector{CoItem}(), pairlist::AbstractVector{CoItem}=Vector{CoItem}(), unique=true, kargs...) + colist=Vector{CoItem}(), itemlist::AbstractVector{CoItem}=Vector{CoItem}(), unique=true, kargs...) length(qtrees) > 1 || return colist nlevel = length(@inbounds qtrees[1]) - empty!(pairlist) + empty!(itemlist) for spindex in keys(spqtree) labels = spqtree[spindex] labelslen = length(labels) if labelslen > 1 for i in 1:labelslen for j in labelslen:-1:i+1 - push!(pairlist, (@inbounds labels[i], @inbounds labels[j]) => spindex) + push!(itemlist, (@inbounds labels[i], @inbounds labels[j]) => spindex) end end end @@ -198,12 +201,12 @@ function totalcollisions_spacial(qtrees::AbstractVector, spqtree::HashSpacialQTr (@inbounds pspindex[1] > nlevel) && break if haskey(spqtree, pspindex) plbs = spqtree[pspindex] - collisions_boundsfilter(qtrees, spindex, labels, plbs, pairlist, colist) + collisions_boundsfilter(qtrees, spindex, labels, plbs, itemlist, colist) end end end - # @show length(pairlist), length(colist) - r = _totalcollisions_native(qtrees, pairlist; colist=colist, kargs...) + # @show length(itemlist), length(colist) + r = _totalcollisions_native(qtrees, itemlist; colist=colist, kargs...) unique ? unique!(first, sort!(r)) : r end function totalcollisions_spacial(qtrees::AbstractVector{U8SQTree}; @@ -218,24 +221,25 @@ function totalcollisions_spacial(qtrees::AbstractVector{U8SQTree}, labels::Union end const SPACIAL_ENABLE_THRESHOLD = round(Int, 10+10log2(Threads.nthreads())) -function totalcollisions_native_kw(qtrees, args...; pairlist=nothing, unique=true, spqtree=nothing, kargs...) - totalcollisions_native(qtrees, args...; kargs...) +function totalcollisions_native_kw(args...; itemlist=nothing, unique=true, spqtree=nothing, kargs...) + totalcollisions_native(args...; kargs...) end -function totalcollisions(qtrees::AbstractVector{U8SQTree}, args...; kargs...) - if length(qtrees) > SPACIAL_ENABLE_THRESHOLD - return totalcollisions_spacial(qtrees, args...; kargs...) +totalcollisions_spacial_kw(args...; pairlist=nothing, kargs...) = totalcollisions_spacial(args...; kargs...) +function totalcollisions(args...; kargs...) + if length(args[end]) > SPACIAL_ENABLE_THRESHOLD + return totalcollisions_spacial_kw(args...; kargs...) else - return totalcollisions_native_kw(qtrees, args...; kargs...) + return totalcollisions_native_kw(args...; kargs...) end end function partialcollisions(qtrees::AbstractVector, spqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees), labels::AbstractSet{Int}=Set(1:length(qtrees)); - colist=Vector{CoItem}(), pairlist::AbstractVector{CoItem}=Vector{CoItem}(), + colist=Vector{CoItem}(), itemlist::AbstractVector{CoItem}=Vector{CoItem}(), lbcollector = Vector{Int}(), treenodestack = Vector{SpacialQTreeNode}(), unique=true, kargs...) - empty!(pairlist) + empty!(itemlist) locate!(qtrees, labels, spqtree) #需要将labels中的label移动到链表首 for label in labels # @show label @@ -250,13 +254,13 @@ function partialcollisions(qtrees::AbstractVector, # 但要保证更prev的node在`labels`中 treenode = seek_treenode(ln) spindex = spacial_index(treenode) - append!(pairlist, (((label, lb) => spindex) for lb in lbcollector)) + append!(itemlist, (((label, lb) => spindex) for lb in lbcollector)) tn = treenode while !isroot(tn) tn = tn.parent #root不是哨兵,值需要遍历 if !isemptylabels(tn) plbs = Iterators.filter(!in(labels), labelsof(tn)) #moved了的plb不加入,等候其向下遍历时加,避免重复 - collisions_boundsfilter(qtrees, spindex, label, plbs, pairlist, colist) + collisions_boundsfilter(qtrees, spindex, label, plbs, itemlist, colist) end end empty!(treenodestack) @@ -271,13 +275,13 @@ function partialcollisions(qtrees::AbstractVector, cspindex = spacial_index(tn) clbs = labelsof(tn) # @show cspindex clbs - collisions_boundsfilter(qtrees, cspindex, clbs, label, pairlist, colist) + collisions_boundsfilter(qtrees, cspindex, clbs, label, itemlist, colist) end for c in tn.children if !isemptychild(tn, c) #如果isemptychild则该child无意义 emptyflag = false push!(treenodestack, c) - # @show pairlist + # @show itemlist end end emptyflag && remove_tree_node(spqtree, tn) @@ -285,8 +289,8 @@ function partialcollisions(qtrees::AbstractVector, end end empty!(labels) - # @show length(pairlist), length(colist) - r = _totalcollisions_native(qtrees, pairlist; colist=colist, kargs...) + # @show length(itemlist), length(colist) + r = _totalcollisions_native(qtrees, itemlist; colist=colist, kargs...) unique ? unique!(first, sort!(r)) : r end mutable struct UpdatedSet{T} <: AbstractSet{T} @@ -305,11 +309,11 @@ Base.length(s::UpdatedSet) = length(s.set) Base.iterate(s::UpdatedSet, args...) = iterate(s.set, args...) Base.in(item, s::UpdatedSet) = in(item, s.set) Base.in(s::UpdatedSet) = in(s.set) -function totalcollisions_kw(qtrees; sptree2=hash_spacial_qtree(qtrees), +function totalcollisions_kw(qtrees; sptqree2=hash_spacial_qtree(qtrees), spqtree=nothing, lbcollector=nothing, treenodestack=nothing, kargs...) - totalcollisions(qtrees; spqtree=sptree2, kargs...) + totalcollisions(qtrees; spqtree=sptqree2, kargs...) end -partialcollisions_kw(qtrees, spqtree, updated; sptree2=nothing, kargs...) = partialcollisions(qtrees, spqtree, updated; kargs...) +partialcollisions_kw(qtrees, spqtree, updated; sptqree2=nothing, pairlist=nothing, kargs...) = partialcollisions(qtrees, spqtree, updated; kargs...) function dynamiccollisions(qtrees::AbstractVector, spqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees), updated::UpdatedSet{Int}=UpdatedSet(1:length(qtrees));