Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ LuxorTensorPlot = ["LuxorGraphPlot"]
AbstractTrees = "0.3, 0.4"
Aqua = "0.8"
CliqueTrees = "1.12.1"
DataStructures = "0.18"
DataStructures = "0.19"
Documenter = "1.10.1"
Graphs = "1"
JSON = "0.21"
JSON = "0.21, 1"
KaHyPar = "0.3"
LuxorGraphPlot = "0.5"
OMEinsum = "0.9"
Expand Down
2 changes: 1 addition & 1 deletion src/OMEinsumContractionOrders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Base.Threads
using AbstractTrees
using TreeWidthSolver
using TreeWidthSolver.Graphs
using DataStructures: PriorityQueue, enqueue!, dequeue!, peek, dequeue_pair!
using DataStructures: PriorityQueue
import CliqueTrees
using CliqueTrees: cliquetree, cliquetree!, separator, residual, CliqueTree, EliminationAlgorithm, MMW, BFS, MCS, LexBFS, RCMMD, RCMGL, MCSM, LexM, AMF, MF, MMD, MF, BT, SafeRules, KaHyParND, METISND, ND, BestWidth, ConnectedComponents

Expand Down
14 changes: 7 additions & 7 deletions src/greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function evaluate_costs(α::TA, incidence_list::IncidenceList{Int,ET}, log2_edge
for vi in vertices(incidence_list)
for vj in neighbors(incidence_list, vi)
if vj > vi
enqueue!(cost_values, (vi,vj), greedy_loss(α, incidence_list, log2_edge_sizes, vi, vj))
push!(cost_values, (vi,vj) => greedy_loss(α, incidence_list, log2_edge_sizes, vi, vj))
add_edge!(cost_graph, vi, vj)
end
end
Expand All @@ -95,9 +95,9 @@ function update_costs!(cost_values, cost_graph, va, vb, α::TA, incidence_list::
vx, vy = minmax(vj, va)
if has_edge(cost_graph, vx, vy)
delete!(cost_values, (vx,vy))
enqueue!(cost_values, (vx,vy), greedy_loss(α, incidence_list, log2_edge_sizes, vx, vy))
push!(cost_values, (vx,vy) => greedy_loss(α, incidence_list, log2_edge_sizes, vx, vy))
else
enqueue!(cost_values, (vx,vy), greedy_loss(α, incidence_list, log2_edge_sizes, vx, vy))
push!(cost_values, (vx,vy) => greedy_loss(α, incidence_list, log2_edge_sizes, vx, vy))
add_edge!(cost_graph, vx, vy)
end
end
Expand All @@ -110,18 +110,18 @@ end

function find_best_cost!(temperature::TT, cost_values::PriorityQueue{PT}, cost_graph) where {PT,TT}
length(cost_values) < 1 && error("cost value information missing")
pair, cost = dequeue_pair!(cost_values)
pair, cost = popfirst!(cost_values)
if iszero(temperature) || isempty(cost_values)
rem_edge!(cost_graph, pair...)
return pair
else
pair2, cost2 = dequeue_pair!(cost_values)
pair2, cost2 = popfirst!(cost_values)
if rand() < exp(-(cost2 - cost) / temperature) # pick 2
enqueue!(cost_values, pair, cost)
push!(cost_values, pair => cost)
rem_edge!(cost_graph, pair2...)
return pair2
else
enqueue!(cost_values, pair2, cost2)
push!(cost_values, pair2 => cost2)
rem_edge!(cost_graph, pair...)
return pair
end
Expand Down
4 changes: 2 additions & 2 deletions src/json.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ function einstodict(eins::EinCode)
return Dict("ixs"=>ixs, "iy"=>iy)
end

function fromdict(::Type{LT}, dict::Dict) where LT
function fromdict(::Type{LT}, dict::AbstractDict) where LT
if dict["isleaf"]
return NestedEinsum{LT}(dict["tensorindex"])
end
eins = einsfromdict(LT, dict["eins"])
return NestedEinsum(fromdict.(LT, dict["args"]), eins)
end

function einsfromdict(::Type{LT}, dict::Dict) where LT
function einsfromdict(::Type{LT}, dict::AbstractDict) where LT
return EinCode([collect(LT, _convert.(LT, ix)) for ix in dict["ixs"]], collect(LT, _convert.(LT, dict["iy"])))
end

Expand Down
2 changes: 1 addition & 1 deletion src/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function merge_greedy(code::EinCode{LT}, size_dict; threshhold=-1e-12) where LT
if isempty(cost_values)
return _buildsimplifier(tree, incidence_list)
end
pair, v = peek(cost_values)
pair, v = first(cost_values)
if v <= threshhold
_, _, c = contract_pair!(incidence_list, pair..., log2_edge_sizes)
tree[pair[1]] = NestedEinsum([tree[pair[1]], tree[pair[2]]], EinCode([c.first...], c.second))
Expand Down
Loading