Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions SeQuant/core/eval/eval_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ EvalExprNode binarize(Product const& prod, IndexSet const& uncontract) {
auto counts = get_used_indices_with_counts(prod);
IndexGroups<IndexVec> result;
for (auto&& [k, v] : counts) {
if (v.nonproto() == 0) continue;
if (v.total() > 1) {
if (uncontracted_idxs.contains(k)) result.aux.emplace_back(k);
continue;
Expand Down
173 changes: 157 additions & 16 deletions SeQuant/core/optimize/single_term.hpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
#ifndef SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP
#define SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP

#include <SeQuant/core/algorithm.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/tensor_canonicalizer.hpp>
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/core/utility/indices.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/external/bliss/graph.hh>

#include <range/v3/view.hpp>

#include <algorithm>
#include <bit>
#include <limits>
#include <type_traits>

namespace sequant::opt {
Expand Down Expand Up @@ -42,22 +47,83 @@ auto constexpr flops_counter(has_index_extent auto&& ixex) {
///
struct OptRes {
  /// Free indices remaining upon evaluation of this subnetwork.
  /// (Diff artifact removed: the superseded `container::svector<sequant::Index>`
  /// declaration duplicated this member and made the struct ill-formed.)
  IndexSet indices;

  /// The flops count of evaluating this subnetwork.
  double flops;

  /// The evaluation sequence that achieves \c flops.
  EvalSequence sequence;

  /// Bitmask bipartition (left/right tensor subsets) that resulted into this
  /// OptRes. Zero means "not yet assigned" (singleton/empty subsets).
  size_t lp = 0;
  size_t rp = 0;

  /// IDs of unique canonical subnets in the optimal tree for this bitmask.
  /// NOTE(review): consumers appear to rely on this staying sorted (merged via
  /// std::set_union elsewhere) — confirm before reordering.
  container::vector<size_t> subnets;
};

/// Hasher for subnetwork canonicalization metadata; forwards to the
/// metadata's own precomputed hash value.
struct SubNetHash {
  size_t operator()(
      TensorNetwork::SlotCanonicalizationMetadata const& data) const noexcept {
    auto const h = data.hash_value();
    return h;
  }
};

/// Equality predicate for subnetwork canonicalization metadata: two entries
/// compare equal when their canonical bliss graphs compare equal.
struct SubNetEqual {
  bool operator()(
      TensorNetwork::SlotCanonicalizationMetadata const& left,
      TensorNetwork::SlotCanonicalizationMetadata const& right) const {
    auto const order = bliss::ConstGraphCmp::cmp(*left.graph, *right.graph);
    return order == 0;
  }
};

/// \brief Finds the optimal evaluation sequence for a single-term tensor
/// contraction.
///
/// This function employs an exhaustive search using dynamic programming to
/// determine the contraction order that minimizes the total cost, as defined by
/// the provided cost function.
///
/// \tparam CostFn A function object type that computes the cost of a single
/// binary contraction.
/// Expected signature:
/// \code double(meta::range_of<Index> auto const& lhs,
/// meta::range_of<Index> auto const& rhs,
/// meta::range_of<Index> auto const& res)
/// \endcode
///
/// \param network The \ref TensorNetwork containing the tensors to be
/// contracted.
/// \param tidxs The set of indices that should remain open in the
/// final result.
/// \param cost_fn The cost model used to evaluate contractions
/// (e.g., flop count).
/// \param subnet_cse If true, enables Common Subexpression
/// Elimination (CSE) for
/// equivalent subnetworks. When enabled, the cost of
/// evaluating structurally identical subnetworks is counted
/// only once in the total cost of a contraction tree.
/// Equivalence is determined by canonicalizing the subnetwork
/// graph.
///
/// \return An \ref EvalSequence representing the optimal contraction order.
///
/// \details The optimization uses a bitmask-based dynamic programming approach
/// where each state represents a subnetwork (subset of tensors).
/// If \p subnet_cse is enabled, the algorithm precomputes canonical
/// metadata for every possible subnetwork to identify common
/// structures. This allows it to find trees that benefit from reusing
/// intermediate results, which is particularly effective for
/// expressions with repeating tensor patterns.
///
template <typename CostFn>
requires requires(CostFn&& fn, decltype(OptRes::indices) const& ixs) {
{ std::forward<CostFn>(fn)(ixs, ixs, ixs) } -> std::floating_point;
{ fn(ixs, ixs, ixs) } -> std::floating_point;
}
EvalSequence single_term_opt_impl(TensorNetwork const& network,
meta::range_of<Index> auto const& tidxs,
CostFn&& cost_fn) {
CostFn&& cost_fn, bool subnet_cse) {
using ranges::views::concat;
using ranges::views::indirect;
using ranges::views::transform;
Expand Down Expand Up @@ -88,26 +154,96 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
}
}

// precompute all subnet_meta if subnet_cse is true
// Note: the O(2^n) cost is bounded in practice — subset_target_indices above
// asserts n <= 24, capping the number of subsets at ~16M.
container::vector<size_t> meta_ids;
container::vector<double> unique_meta_costs;
if (subnet_cse) {
// Use max as sentinel for entries with popcount < 2 (singletons/empty),
// which are skipped below and never assigned a real meta ID.
meta_ids.resize(results.size(), std::numeric_limits<size_t>::max());
container::unordered_map<TensorNetwork::SlotCanonicalizationMetadata,
size_t, SubNetHash, SubNetEqual>
meta_to_id;

for (size_t n = 0; n < results.size(); ++n) {
if (std::popcount(n) < 2) continue;
auto ts = bits::on_bits_index(n) | bits::sieve(network.tensors());
container::vector<ExprPtr> ts_expr;
for (auto&& t : ts) {
ts_expr.emplace_back(std::dynamic_pointer_cast<Tensor>(t)->clone());
}
auto tn = TensorNetwork{ts_expr};
auto meta = tn.canonicalize_slots(
TensorCanonicalizer::cardinal_tensor_labels(), &results[n].indices);

auto [it, inserted] = meta_to_id.try_emplace(std::move(meta), 0);
if (inserted) {
it->second = meta_to_id.size() - 1;
}
meta_ids[n] = it->second;
}
unique_meta_costs.resize(meta_to_id.size(), 0.0);
}

// find the optimal evaluation sequence
for (size_t n = 0; n < results.size(); ++n) {
if (std::popcount(n) < 2) continue;
std::pair<size_t, size_t> curr_parts{0, 0};
for (auto& curr_cost = results[n].flops;
auto&& [lp, rp] : bits::bipartitions(n)) {
// do nothing with the trivial bipartition
// i.e. one subset is the empty set and the other full
if (lp == 0 || rp == 0) continue;
auto new_cost = std::forward<CostFn>(cost_fn)(results[lp].indices, //
results[rp].indices, //
results[n].indices) //
+ results[lp].flops + results[rp].flops;

double new_cost = 0;
container::vector<size_t> combined_subnets;
if (subnet_cse) {
// subnets is always kept sorted; set_union requires sorted inputs and
// produces sorted output — this invariant is maintained throughout.
std::set_union(results[lp].subnets.begin(), results[lp].subnets.end(),
results[rp].subnets.begin(), results[rp].subnets.end(),
std::back_inserter(combined_subnets));
new_cost = cost_fn(results[lp].indices, //
results[rp].indices, //
results[n].indices);
for (auto id : combined_subnets) {
new_cost += unique_meta_costs[id];
}
} else {
new_cost = cost_fn(results[lp].indices, //
results[rp].indices, //
results[n].indices) //
+ results[lp].flops + results[rp].flops;
}

if (new_cost <= curr_cost) {
curr_cost = new_cost;
curr_parts = decltype(curr_parts){lp, rp};
results[n].lp = lp;
results[n].rp = rp;
if (subnet_cse) {
results[n].subnets = std::move(combined_subnets);
}
}
}
auto const& lseq = results[curr_parts.first].sequence;
auto const& rseq = results[curr_parts.second].sequence;

if (subnet_cse) {
auto mid = meta_ids[n];
// Canonically equivalent subnetworks share the same topology and index
// sizes, so their cost is identical. Overwriting with a later bitmask's
// cost is intentional and benign.
unique_meta_costs[mid] =
cost_fn(results[results[n].lp].indices,
results[results[n].rp].indices, results[n].indices);
auto it = std::lower_bound(results[n].subnets.begin(),
results[n].subnets.end(), mid);
if (it == results[n].subnets.end() || *it != mid) {
results[n].subnets.insert(it, mid);
}
}

auto const& lseq = results[results[n].lp].sequence;
auto const& rseq = results[results[n].rp].sequence;
results[n].sequence =
(lseq[0] < rseq[0] ? concat(lseq, rseq) : concat(rseq, lseq)) |
ranges::to<EvalSequence>;
Expand All @@ -121,15 +257,19 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
/// \tparam IdxToSz
/// \param network A TensorNetwork object.
/// \param idxsz An invocable on Index, that maps Index to its dimension.
/// \param subnet_cse Whether to recognize equivalent subnetworks to try
///        minimizing the ops counts.
/// \return Optimal evaluation sequence that minimizes flops. If there are
///         equivalent optimal sequences then the result is the one that keeps
///         the order of tensors in the network as original as possible.
///
template <has_index_extent IdxToSz>
EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz,
                             bool subnet_cse) {
  // Diff artifact removed: the superseded two-argument signature and its
  // return statement were interleaved with the new ones, making this
  // definition ill-formed.
  //
  // Build the flop-count cost model from the index-extent map.
  auto cost_fn = flops_counter(std::forward<IdxToSz>(idxsz));
  // No target (free) indices are imposed on the final result here.
  decltype(OptRes::indices) tidxs{};
  return single_term_opt_impl(network, tidxs, cost_fn, subnet_cse);
}

} // namespace detail
Expand All @@ -142,7 +282,8 @@ EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz) {
/// @note @c prod is assumed to consist of only Tensor expressions
///
template <has_index_extent IdxToSz>
ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) {
ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz,
bool subnet_cse = false) {
using ranges::views::filter;
using ranges::views::reverse;

Expand All @@ -152,7 +293,7 @@ ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) {
auto const tensors =
prod | filter(&ExprPtr::template is<Tensor>) | ranges::to_vector;
auto seq = detail::single_term_opt(TensorNetwork{tensors},
std::forward<IdxToSz>(idxsz));
std::forward<IdxToSz>(idxsz), subnet_cse);
auto result = container::svector<ExprPtr>{};
for (auto i : seq)
if (i == -1) {
Expand Down
103 changes: 98 additions & 5 deletions tests/unit/test_optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <cstddef>
#include <initializer_list>
#include <memory>
#include <stdexcept>

#include <range/v3/all.hpp>

Expand Down Expand Up @@ -323,9 +322,103 @@ TEST_CASE("optimize", "[optimize]") {
}
}

SECTION("Single term optimization with CSE") {
auto ctx_resetter =
set_scoped_default_context(get_default_context().clone());
auto reg = get_default_context().mutable_index_space_registry();
mbpt::add_df_spaces(reg);
mbpt::add_pao_spaces(reg);
mbpt::add_ao_spaces(reg);
// i 10
// a 40
// μ̃ 50
// Κ 90
for (auto&& [k, v] :
std::initializer_list<std::pair<std::wstring_view, size_t>>{
{L"i", 10}, {L"a", 40}, {L"μ̃", 50}, {L"Κ", 90}}) {
reg->retrieve_ptr(k)->approximate_size(v);
}

auto single_term_opt = [](Product const& prod, bool cse = true) {
return opt::single_term_opt(
prod,
[](Index const& ix) {
// null space contributes x1 to the size
auto sz = ix.nonnull() ? ix.space().approximate_size() : 1;
return sz;
},
/*subnet_cse=*/cse);
};

auto prod9 =
deserialize("X{i1;a1} X{i2;a2} Y{a2;i3} Y{a1;i4}")->as<Product>();
auto res9 = single_term_opt(prod9);
auto res9_no_cse = single_term_opt(prod9, false);
// this is the one we want to find
// (X Y) (X Y)
REQUIRE(extract(res9, {0, 0}) == prod9.at(0));
REQUIRE(extract(res9, {0, 1}) == prod9.at(3));
REQUIRE(extract(res9, {1, 0}) == prod9.at(1));
REQUIRE(extract(res9, {1, 1}) == prod9.at(2));

// take a look at res9_no_cse for a result with subnet_cse disabled
// should give the same result in this case as it's already optimal
REQUIRE(extract(res9_no_cse, {0, 0}) == prod9.at(0));
REQUIRE(extract(res9_no_cse, {0, 1}) == prod9.at(3));
REQUIRE(extract(res9_no_cse, {1, 0}) == prod9.at(1));
REQUIRE(extract(res9_no_cse, {1, 1}) == prod9.at(2));

SECTION("CSE effect on optimization result") {
auto ctx_resetter =
set_scoped_default_context(get_default_context().clone());
auto reg = get_default_context().mutable_index_space_registry();
// Use sizes that make the unbalanced tree better without CSE,
// but the balanced tree better with CSE.
// Balanced: ( (X1 Y1) (X2 Y2) )
// Cost(X1*Y1) = size(i)*size(a)*size(j) = 12*10*12 = 1440.
// Cost(Inter) = 12^3 = 1728.
// Total no-CSE: 2*1440 + 1728 = 4608.
// Total CSE: 1440 + 1728 = 3168.
// Unbalanced: ( ( (X1 Y1) X2 ) Y2 )
// Cost(X1*Y1) = 12*10*12 = 1440.
// Cost((X1*Y1)*X2) = size(i)*size(i)*size(a) = 12*12*10 = 1440.
// Cost(...) * Y2 = 12*10*12 = 1440.
// Total Unbalanced: 1440 + 1440 + 1440 = 4320.
// 3168 < 4320 < 4608.
reg->retrieve_ptr(L"i")->approximate_size(12);
reg->retrieve_ptr(L"a")->approximate_size(10);

auto single_term_opt = [](Product const& prod, bool cse) {
return opt::single_term_opt(
prod,
[](Index const& ix) {
return ix.nonnull() ? ix.space().approximate_size() : 1;
},
cse);
};

// X{i1;a1} Y{a1;i2} X{i2;a2} Y{a2;i3}
auto prod =
deserialize(L"X{i1;a1} Y{a1;i2} X{i2;a2} Y{a2;i3}")->as<Product>();

auto res_cse = single_term_opt(prod, true);
auto res_no_cse = single_term_opt(prod, false);

// With CSE: Balanced tree
REQUIRE(res_cse->as<Product>().factors().size() == 2);
REQUIRE(res_cse->at(0)->is<Product>());
REQUIRE(res_cse->at(1)->is<Product>());

// Without CSE: Unbalanced tree
bool is_unbalanced =
(res_no_cse->at(0)->is<Tensor>() || res_no_cse->at(1)->is<Tensor>());
REQUIRE(is_unbalanced);
}
}

/// verify that space changes did not leak
auto reg = get_default_context().index_space_registry();
auto uocc = reg->retrieve_ptr(L"a");
REQUIRE(uocc);
REQUIRE(uocc->approximate_size() == 10);
auto reg_check = get_default_context().index_space_registry();
auto uocc_check = reg_check->retrieve_ptr(L"a");
REQUIRE(uocc_check);
REQUIRE(uocc_check->approximate_size() == 10);
}