diff --git a/SeQuant/core/eval/eval_expr.cpp b/SeQuant/core/eval/eval_expr.cpp index e805cea82e..285130e8bd 100644 --- a/SeQuant/core/eval/eval_expr.cpp +++ b/SeQuant/core/eval/eval_expr.cpp @@ -461,6 +461,7 @@ EvalExprNode binarize(Product const& prod, IndexSet const& uncontract) { auto counts = get_used_indices_with_counts(prod); IndexGroups result; for (auto&& [k, v] : counts) { + if (v.nonproto() == 0) continue; if (v.total() > 1) { if (uncontracted_idxs.contains(k)) result.aux.emplace_back(k); continue; diff --git a/SeQuant/core/optimize/single_term.hpp b/SeQuant/core/optimize/single_term.hpp index 345d0f10f3..3699f472fc 100644 --- a/SeQuant/core/optimize/single_term.hpp +++ b/SeQuant/core/optimize/single_term.hpp @@ -1,15 +1,20 @@ #ifndef SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP #define SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP +#include #include #include +#include #include #include #include +#include #include +#include #include +#include #include namespace sequant::opt { @@ -42,22 +47,83 @@ auto constexpr flops_counter(has_index_extent auto&& ixex) { /// struct OptRes { /// Free indices remaining upon evaluation - container::svector indices; + IndexSet indices; /// The flops count of evaluation double flops; /// The evaluation sequence EvalSequence sequence; + + /// Bitmask splits that resulted in this OptRes + size_t lp = 0; + size_t rp = 0; + + /// Sorted ids of the unique canonical subnets in the optimal tree for this bitmask + container::vector subnets; +}; + +struct SubNetHash { + size_t operator()( + TensorNetwork::SlotCanonicalizationMetadata const& data) const noexcept { + return data.hash_value(); + } +}; + +struct SubNetEqual { + bool operator()( + TensorNetwork::SlotCanonicalizationMetadata const& left, + TensorNetwork::SlotCanonicalizationMetadata const& right) const { + return bliss::ConstGraphCmp::cmp(*left.graph, *right.graph) == 0; + } }; +/// \brief Finds the optimal evaluation sequence for a single-term tensor +/// contraction. 
+/// +/// This function employs an exhaustive search using dynamic programming to +/// determine the contraction order that minimizes the total cost, as defined by +/// the provided cost function. +/// +/// \tparam CostFn A function object type that computes the cost of a single +/// binary contraction. +/// Expected signature: +/// \code double(meta::range_of auto const& lhs, +/// meta::range_of auto const& rhs, +/// meta::range_of auto const& res) +/// \endcode +/// +/// \param network The \ref TensorNetwork containing the tensors to be +/// contracted. +/// \param tidxs The set of indices that should remain open in the +/// final result. +/// \param cost_fn The cost model used to evaluate contractions +/// (e.g., flop count). +/// \param subnet_cse If true, enables Common Subexpression +/// Elimination (CSE) for +/// equivalent subnetworks. When enabled, the cost of +/// evaluating structurally identical subnetworks is counted +/// only once in the total cost of a contraction tree. +/// Equivalence is determined by canonicalizing the subnetwork +/// graph. +/// +/// \return An \ref EvalSequence representing the optimal contraction order. +/// +/// \details The optimization uses a bitmask-based dynamic programming approach +/// where each state represents a subnetwork (subset of tensors). +/// If \p subnet_cse is enabled, the algorithm precomputes canonical +/// metadata for every possible subnetwork to identify common +/// structures. This allows it to find trees that benefit from reusing +/// intermediate results, which is particularly effective for +/// expressions with repeating tensor patterns. 
+/// template requires requires(CostFn&& fn, decltype(OptRes::indices) const& ixs) { - { std::forward(fn)(ixs, ixs, ixs) } -> std::floating_point; + { fn(ixs, ixs, ixs) } -> std::floating_point; } EvalSequence single_term_opt_impl(TensorNetwork const& network, meta::range_of auto const& tidxs, - CostFn&& cost_fn) { + CostFn&& cost_fn, bool subnet_cse) { using ranges::views::concat; using ranges::views::indirect; using ranges::views::transform; @@ -88,26 +154,96 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network, } } + // precompute all subnet_meta if subnet_cse is true + // Note: the O(2^n) cost is bounded in practice — subset_target_indices above + // asserts n <= 24, capping the number of subsets at ~16M. + container::vector meta_ids; + container::vector unique_meta_costs; + if (subnet_cse) { + // Use max as sentinel for entries with popcount < 2 (singletons/empty), + // which are skipped below and never assigned a real meta ID. + meta_ids.resize(results.size(), std::numeric_limits::max()); + container::unordered_map + meta_to_id; + + for (size_t n = 0; n < results.size(); ++n) { + if (std::popcount(n) < 2) continue; + auto ts = bits::on_bits_index(n) | bits::sieve(network.tensors()); + container::vector ts_expr; + for (auto&& t : ts) { + ts_expr.emplace_back(std::dynamic_pointer_cast(t)->clone()); + } + auto tn = TensorNetwork{ts_expr}; + auto meta = tn.canonicalize_slots( + TensorCanonicalizer::cardinal_tensor_labels(), &results[n].indices); + + auto [it, inserted] = meta_to_id.try_emplace(std::move(meta), 0); + if (inserted) { + it->second = meta_to_id.size() - 1; + } + meta_ids[n] = it->second; + } + unique_meta_costs.resize(meta_to_id.size(), 0.0); + } + // find the optimal evaluation sequence for (size_t n = 0; n < results.size(); ++n) { if (std::popcount(n) < 2) continue; - std::pair curr_parts{0, 0}; for (auto& curr_cost = results[n].flops; auto&& [lp, rp] : bits::bipartitions(n)) { // do nothing with the trivial bipartition // i.e. 
one subset is the empty set and the other full if (lp == 0 || rp == 0) continue; - auto new_cost = std::forward(cost_fn)(results[lp].indices, // - results[rp].indices, // - results[n].indices) // - + results[lp].flops + results[rp].flops; + + double new_cost = 0; + container::vector combined_subnets; + if (subnet_cse) { + // subnets is always kept sorted; set_union requires sorted inputs and + // produces sorted output — this invariant is maintained throughout. + std::set_union(results[lp].subnets.begin(), results[lp].subnets.end(), + results[rp].subnets.begin(), results[rp].subnets.end(), + std::back_inserter(combined_subnets)); + new_cost = cost_fn(results[lp].indices, // + results[rp].indices, // + results[n].indices); + for (auto id : combined_subnets) { + new_cost += unique_meta_costs[id]; + } + } else { + new_cost = cost_fn(results[lp].indices, // + results[rp].indices, // + results[n].indices) // + + results[lp].flops + results[rp].flops; + } + if (new_cost <= curr_cost) { curr_cost = new_cost; - curr_parts = decltype(curr_parts){lp, rp}; + results[n].lp = lp; + results[n].rp = rp; + if (subnet_cse) { + results[n].subnets = std::move(combined_subnets); + } } } - auto const& lseq = results[curr_parts.first].sequence; - auto const& rseq = results[curr_parts.second].sequence; + + if (subnet_cse) { + auto mid = meta_ids[n]; + // Canonically equivalent subnetworks share the same topology and index + // sizes, so their cost is identical. Overwriting with a later bitmask's + // cost is intentional and benign. 
+ unique_meta_costs[mid] = + cost_fn(results[results[n].lp].indices, + results[results[n].rp].indices, results[n].indices); + auto it = std::lower_bound(results[n].subnets.begin(), + results[n].subnets.end(), mid); + if (it == results[n].subnets.end() || *it != mid) { + results[n].subnets.insert(it, mid); + } + } + + auto const& lseq = results[results[n].lp].sequence; + auto const& rseq = results[results[n].rp].sequence; results[n].sequence = (lseq[0] < rseq[0] ? concat(lseq, rseq) : concat(rseq, lseq)) | ranges::to; @@ -121,15 +257,19 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network, /// \tparam IdxToSz /// \param network A TensorNetwork object. /// \param idxsz An invocable on Index, that maps Index to its dimension. -/// \return Optimal evaluation sequence that minimizes flops. If there are +/// \param subnet_cse Whether to recognize equivalent subnetworks so that +/// their cost is counted only once. +/// \return Optimal evaluation sequence that +/// minimizes flops. If there are /// equivalent optimal sequences then the result is the one that keeps /// the order of tensors in the network as original as possible. 
/// template -EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz) { +EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz, + bool subnet_cse) { auto cost_fn = flops_counter(std::forward(idxsz)); decltype(OptRes::indices) tidxs{}; - return single_term_opt_impl(network, tidxs, cost_fn); + return single_term_opt_impl(network, tidxs, cost_fn, subnet_cse); } } // namespace detail @@ -142,7 +282,8 @@ EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz) { /// @note @c prod is assumed to consist of only Tensor expressions /// template -ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) { +ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz, + bool subnet_cse = false) { using ranges::views::filter; using ranges::views::reverse; @@ -152,7 +293,7 @@ ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) { auto const tensors = prod | filter(&ExprPtr::template is) | ranges::to_vector; auto seq = detail::single_term_opt(TensorNetwork{tensors}, - std::forward(idxsz)); + std::forward(idxsz), subnet_cse); auto result = container::svector{}; for (auto i : seq) if (i == -1) { diff --git a/tests/unit/test_optimize.cpp b/tests/unit/test_optimize.cpp index 98eab7803c..8049c610a8 100644 --- a/tests/unit/test_optimize.cpp +++ b/tests/unit/test_optimize.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include @@ -323,9 +322,103 @@ TEST_CASE("optimize", "[optimize]") { } } + SECTION("Single term optimization with CSE") { + auto ctx_resetter = + set_scoped_default_context(get_default_context().clone()); + auto reg = get_default_context().mutable_index_space_registry(); + mbpt::add_df_spaces(reg); + mbpt::add_pao_spaces(reg); + mbpt::add_ao_spaces(reg); + // i 10 + // a 40 + // μ̃ 50 + // Κ 90 + for (auto&& [k, v] : + std::initializer_list>{ + {L"i", 10}, {L"a", 40}, {L"μ̃", 50}, {L"Κ", 90}}) { + reg->retrieve_ptr(k)->approximate_size(v); + } + + auto single_term_opt = [](Product const& 
prod, bool cse = true) { + return opt::single_term_opt( + prod, + [](Index const& ix) { + // null space contributes x1 to the size + auto sz = ix.nonnull() ? ix.space().approximate_size() : 1; + return sz; + }, + /*subnet_cse=*/cse); + }; + + auto prod9 = + deserialize("X{i1;a1} X{i2;a2} Y{a2;i3} Y{a1;i4}")->as(); + auto res9 = single_term_opt(prod9); + auto res9_no_cse = single_term_opt(prod9, false); + // this is the one we want to find + // (X Y) (X Y) + REQUIRE(extract(res9, {0, 0}) == prod9.at(0)); + REQUIRE(extract(res9, {0, 1}) == prod9.at(3)); + REQUIRE(extract(res9, {1, 0}) == prod9.at(1)); + REQUIRE(extract(res9, {1, 1}) == prod9.at(2)); + + // take a look at res9_no_cse for a result with subnet_cse disabled + // should give the same result in this case as it's already optimal + REQUIRE(extract(res9_no_cse, {0, 0}) == prod9.at(0)); + REQUIRE(extract(res9_no_cse, {0, 1}) == prod9.at(3)); + REQUIRE(extract(res9_no_cse, {1, 0}) == prod9.at(1)); + REQUIRE(extract(res9_no_cse, {1, 1}) == prod9.at(2)); + + SECTION("CSE effect on optimization result") { + auto ctx_resetter = + set_scoped_default_context(get_default_context().clone()); + auto reg = get_default_context().mutable_index_space_registry(); + // Use sizes that make the unbalanced tree better without CSE, + // but the balanced tree better with CSE. + // Balanced: ( (X1 Y1) (X2 Y2) ) + // Cost(X1*Y1) = size(i)*size(a)*size(i) = 12*10*12 = 1440. + // Cost(Inter) = 12^3 = 1728. + // Total no-CSE: 2*1440 + 1728 = 4608. + // Total CSE: 1440 + 1728 = 3168. + // Unbalanced: ( ( (X1 Y1) X2 ) Y2 ) + // Cost(X1*Y1) = 12*10*12 = 1440. + // Cost((X1*Y1)*X2) = size(i)*size(i)*size(a) = 12*12*10 = 1440. + // Cost(...) * Y2 = 12*10*12 = 1440. + // Total Unbalanced: 1440 + 1440 + 1440 = 4320. + // 3168 < 4320 < 4608. 
+ reg->retrieve_ptr(L"i")->approximate_size(12); + reg->retrieve_ptr(L"a")->approximate_size(10); + + auto single_term_opt = [](Product const& prod, bool cse) { + return opt::single_term_opt( + prod, + [](Index const& ix) { + return ix.nonnull() ? ix.space().approximate_size() : 1; + }, + cse); + }; + + // X{i1;a1} Y{a1;i2} X{i2;a2} Y{a2;i3} + auto prod = + deserialize(L"X{i1;a1} Y{a1;i2} X{i2;a2} Y{a2;i3}")->as(); + + auto res_cse = single_term_opt(prod, true); + auto res_no_cse = single_term_opt(prod, false); + + // With CSE: Balanced tree + REQUIRE(res_cse->as().factors().size() == 2); + REQUIRE(res_cse->at(0)->is()); + REQUIRE(res_cse->at(1)->is()); + + // Without CSE: Unbalanced tree + bool is_unbalanced = + (res_no_cse->at(0)->is() || res_no_cse->at(1)->is()); + REQUIRE(is_unbalanced); + } + } + /// verify that space changes did not leak - auto reg = get_default_context().index_space_registry(); - auto uocc = reg->retrieve_ptr(L"a"); - REQUIRE(uocc); - REQUIRE(uocc->approximate_size() == 10); + auto reg_check = get_default_context().index_space_registry(); + auto uocc_check = reg_check->retrieve_ptr(L"a"); + REQUIRE(uocc_check); + REQUIRE(uocc_check->approximate_size() == 10); }