Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ LuxorTensorPlot = ["LuxorGraphPlot"]
AbstractTrees = "0.3, 0.4"
Aqua = "0.8"
CliqueTrees = "1.12.1"
DataStructures = "0.18"
DataStructures = "0.19"
Documenter = "1.10.1"
Graphs = "1"
JSON = "0.21"
JSON = "0.21, 1"
KaHyPar = "0.3"
LuxorGraphPlot = "0.5"
OMEinsum = "0.9"
Expand Down
2 changes: 1 addition & 1 deletion src/OMEinsumContractionOrders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Base.Threads
using AbstractTrees
using TreeWidthSolver
using TreeWidthSolver.Graphs
using DataStructures: PriorityQueue, enqueue!, dequeue!, peek, dequeue_pair!
using DataStructures: PriorityQueue
import CliqueTrees
using CliqueTrees: cliquetree, cliquetree!, separator, residual, CliqueTree, EliminationAlgorithm, MMW, BFS, MCS, LexBFS, RCMMD, RCMGL, MCSM, LexM, AMF, MF, MMD, MF, BT, SafeRules, KaHyParND, METISND, ND, BestWidth, ConnectedComponents

Expand Down
12 changes: 6 additions & 6 deletions src/greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function evaluate_costs(α::TA, incidence_list::IncidenceList{Int,ET}, log2_edge
for vi in vertices(incidence_list)
for vj in neighbors(incidence_list, vi)
if vj > vi
enqueue!(cost_values, (vi,vj), greedy_loss(α, incidence_list, log2_edge_sizes, vi, vj))
push!(cost_values, (vi,vj) => greedy_loss(α, incidence_list, log2_edge_sizes, vi, vj))
add_edge!(cost_graph, vi, vj)
end
end
Expand All @@ -95,9 +95,9 @@ function update_costs!(cost_values, cost_graph, va, vb, α::TA, incidence_list::
vx, vy = minmax(vj, va)
if has_edge(cost_graph, vx, vy)
delete!(cost_values, (vx,vy))
enqueue!(cost_values, (vx,vy), greedy_loss(α, incidence_list, log2_edge_sizes, vx, vy))
push!(cost_values, (vx,vy) => greedy_loss(α, incidence_list, log2_edge_sizes, vx, vy))
else
enqueue!(cost_values, (vx,vy), greedy_loss(α, incidence_list, log2_edge_sizes, vx, vy))
push!(cost_values, (vx,vy) => greedy_loss(α, incidence_list, log2_edge_sizes, vx, vy))
add_edge!(cost_graph, vx, vy)
end
end
Expand All @@ -110,18 +110,18 @@ end

function find_best_cost!(temperature::TT, cost_values::PriorityQueue{PT}, cost_graph) where {PT,TT}
length(cost_values) < 1 && error("cost value information missing")
pair, cost = dequeue_pair!(cost_values)
pair, cost = popfirst!(cost_values)
if iszero(temperature) || isempty(cost_values)
rem_edge!(cost_graph, pair...)
return pair
else
pair2, cost2 = dequeue_pair!(cost_values)
if rand() < exp(-(cost2 - cost) / temperature) # pick 2
enqueue!(cost_values, pair, cost)
push!(cost_values, pair => cost)
rem_edge!(cost_graph, pair2...)
return pair2
else
enqueue!(cost_values, pair2, cost2)
push!(cost_values, pair2 => cost2)
rem_edge!(cost_graph, pair...)
return pair
end
Expand Down
4 changes: 2 additions & 2 deletions src/json.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ function einstodict(eins::EinCode)
return Dict("ixs"=>ixs, "iy"=>iy)
end

function fromdict(::Type{LT}, dict::Dict) where LT
function fromdict(::Type{LT}, dict::AbstractDict) where LT
if dict["isleaf"]
return NestedEinsum{LT}(dict["tensorindex"])
end
eins = einsfromdict(LT, dict["eins"])
return NestedEinsum(fromdict.(LT, dict["args"]), eins)
end

function einsfromdict(::Type{LT}, dict::Dict) where LT
function einsfromdict(::Type{LT}, dict::AbstractDict) where LT
return EinCode([collect(LT, _convert.(LT, ix)) for ix in dict["ixs"]], collect(LT, _convert.(LT, dict["iy"])))
end

Expand Down
2 changes: 1 addition & 1 deletion src/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function merge_greedy(code::EinCode{LT}, size_dict; threshhold=-1e-12) where LT
if isempty(cost_values)
return _buildsimplifier(tree, incidence_list)
end
pair, v = peek(cost_values)
pair, v = first(cost_values)
if v <= threshhold
_, _, c = contract_pair!(incidence_list, pair..., log2_edge_sizes)
tree[pair[1]] = NestedEinsum([tree[pair[1]], tree[pair[2]]], EinCode([c.first...], c.second))
Expand Down
Loading