Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 146 additions & 15 deletions SeQuant/core/optimize/single_term.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
#include <SeQuant/core/tensor_network.hpp>
#include <SeQuant/core/utility/indices.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/external/bliss/graph.hh>

#include <range/v3/view.hpp>

#include <SeQuant/core/algorithm.hpp>
#include <SeQuant/core/tensor_canonicalizer.hpp>
#include <algorithm>
#include <bit>
#include <type_traits>

Expand Down Expand Up @@ -42,22 +46,83 @@ auto constexpr flops_counter(has_index_extent auto&& ixex) {
///
/// \brief Dynamic-programming state for one subnetwork (one bitmask over the
///        tensors of the network) during single-term optimization.
struct OptRes {
  /// Free indices remaining upon evaluation of this subnetwork
  IndexSet indices;

  /// The flops count of evaluating this subnetwork
  double flops;

  /// The evaluation sequence that realizes \c flops
  EvalSequence sequence;

  /// Bitmask bipartition (left/right parts) that resulted in this OptRes;
  /// 0 means "not yet set" (leaves and unprocessed states)
  size_t lp = 0;
  size_t rp = 0;

  /// ids of the unique canonical subnets appearing in the optimal tree for
  /// this bitmask (sorted, duplicate-free)
  container::vector<uint16_t> subnets;
};

struct SubNetHash {
size_t operator()(
TensorNetwork::SlotCanonicalizationMetadata const& data) const noexcept {
return data.hash_value();
}
};

/// Equality predicate for slot-canonicalization metadata: two subnetworks
/// compare equal iff their canonical (bliss) graphs compare equal.
/// NOTE(review): this is an equivalence relation, not a strict weak ordering;
/// together with SubNetHash it looks intended for an unordered associative
/// container. Passing it as the Compare parameter of an ordered
/// container::map would be incorrect -- verify the use sites.
struct SubNetEqual {
  bool operator()(
      TensorNetwork::SlotCanonicalizationMetadata const& left,
      TensorNetwork::SlotCanonicalizationMetadata const& right) const {
    return bliss::ConstGraphCmp::cmp(*left.graph, *right.graph) == 0;
  }
};

/// \brief Finds the optimal evaluation sequence for a single-term tensor
/// contraction.
///
/// This function employs an exhaustive search using dynamic programming to
/// determine the contraction order that minimizes the total cost, as defined by
/// the provided cost function.
///
/// \tparam CostFn A function object type that computes the cost of a single
///         binary contraction.
///         Expected signature:
///         \code
///         double(meta::range_of<Index> auto const& lhs,
///                meta::range_of<Index> auto const& rhs,
///                meta::range_of<Index> auto const& res)
///         \endcode
///
/// \param network The \ref TensorNetwork containing the tensors to be
///        contracted.
/// \param tidxs The set of indices that should remain open in the
///        final result.
/// \param cost_fn The cost model used to evaluate contractions
///        (e.g., flop count).
/// \param subnet_cse If true, enables Common Subexpression Elimination (CSE)
///        for equivalent subnetworks. When enabled, the cost of
///        evaluating structurally identical subnetworks is counted
///        only once in the total cost of a contraction tree.
///        Equivalence is determined by canonicalizing the subnetwork
///        graph.
///
/// \return An \ref EvalSequence representing the optimal contraction order.
///
/// \details The optimization uses a bitmask-based dynamic programming approach
/// where each state represents a subnetwork (subset of tensors).
/// If \p subnet_cse is enabled, the algorithm precomputes canonical
/// metadata for every possible subnetwork to identify common
/// structures. This allows it to find trees that benefit from reusing
/// intermediate results, which is particularly effective for
/// expressions with repeating tensor patterns.
///
template <typename CostFn>
requires requires(CostFn&& fn, decltype(OptRes::indices) const& ixs) {
{ std::forward<CostFn>(fn)(ixs, ixs, ixs) } -> std::floating_point;
}
EvalSequence single_term_opt_impl(TensorNetwork const& network,
meta::range_of<Index> auto const& tidxs,
CostFn&& cost_fn) {
CostFn&& cost_fn, bool subnet_cse) {
using ranges::views::concat;
using ranges::views::indirect;
using ranges::views::transform;
Expand Down Expand Up @@ -88,26 +153,87 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
}
}

// precompute all subnet_meta if subnet_cse is true
container::vector<uint16_t> meta_ids;
container::vector<double> unique_meta_costs;
if (subnet_cse) {
meta_ids.resize(results.size(), 0);
container::map<TensorNetwork::SlotCanonicalizationMetadata, uint16_t,
SubNetEqual>
meta_to_id;

for (size_t n = 0; n < results.size(); ++n) {
if (std::popcount(n) < 2) continue;
auto ts = bits::on_bits_index(n) | bits::sieve(network.tensors());
container::vector<ExprPtr> ts_expr;
for (auto&& t : ts) {
ts_expr.emplace_back(std::dynamic_pointer_cast<Tensor>(t)->clone());
}
auto tn = TensorNetwork{ts_expr};
auto meta = tn.canonicalize_slots(
TensorCanonicalizer::cardinal_tensor_labels(), &results[n].indices);

auto [it, inserted] = meta_to_id.try_emplace(std::move(meta), 0);
if (inserted) {
it->second = meta_to_id.size() - 1;
}
meta_ids[n] = it->second;
}
unique_meta_costs.resize(meta_to_id.size(), 0.0);
}

// find the optimal evaluation sequence
for (size_t n = 0; n < results.size(); ++n) {
if (std::popcount(n) < 2) continue;
std::pair<size_t, size_t> curr_parts{0, 0};
for (auto& curr_cost = results[n].flops;
auto&& [lp, rp] : bits::bipartitions(n)) {
// do nothing with the trivial bipartition
// i.e. one subset is the empty set and the other full
if (lp == 0 || rp == 0) continue;
auto new_cost = std::forward<CostFn>(cost_fn)(results[lp].indices, //
results[rp].indices, //
results[n].indices) //
+ results[lp].flops + results[rp].flops;

double new_cost = 0;
container::vector<uint16_t> combined_subnets;
if (subnet_cse) {
std::set_union(results[lp].subnets.begin(), results[lp].subnets.end(),
results[rp].subnets.begin(), results[rp].subnets.end(),
std::back_inserter(combined_subnets));
new_cost = std::forward<CostFn>(cost_fn)(results[lp].indices, //
results[rp].indices, //
results[n].indices);
for (auto id : combined_subnets) {
new_cost += unique_meta_costs[id];
}
} else {
new_cost = std::forward<CostFn>(cost_fn)(results[lp].indices, //
results[rp].indices, //
results[n].indices) //
+ results[lp].flops + results[rp].flops;
}

if (new_cost <= curr_cost) {
curr_cost = new_cost;
curr_parts = decltype(curr_parts){lp, rp};
results[n].lp = lp;
results[n].rp = rp;
if (subnet_cse) {
results[n].subnets = std::move(combined_subnets);
}
}
}
auto const& lseq = results[curr_parts.first].sequence;
auto const& rseq = results[curr_parts.second].sequence;

if (subnet_cse) {
auto mid = meta_ids[n];
unique_meta_costs[mid] = std::forward<CostFn>(cost_fn)(
results[results[n].lp].indices, results[results[n].rp].indices,
results[n].indices);
auto it = std::lower_bound(results[n].subnets.begin(),
results[n].subnets.end(), mid);
if (it == results[n].subnets.end() || *it != mid) {
results[n].subnets.insert(it, mid);
}
}

auto const& lseq = results[results[n].lp].sequence;
auto const& rseq = results[results[n].rp].sequence;
results[n].sequence =
(lseq[0] < rseq[0] ? concat(lseq, rseq) : concat(rseq, lseq)) |
ranges::to<EvalSequence>;
Expand All @@ -121,15 +247,19 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
/// \tparam IdxToSz
/// \param network A TensorNetwork object.
/// \param idxsz An invocable on Index, that maps an Index to its dimension.
/// \param subnet_cse Whether to recognize equivalent subnetworks to try to
///        minimize the op count.
/// \return Optimal evaluation sequence that minimizes flops. If there are
///         equivalent optimal sequences then the result is the one that keeps
///         the order of tensors in the network as original as possible.
///
template <has_index_extent IdxToSz>
EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz,
                             bool subnet_cse) {
  // Build the flop-counting cost model from the index-extent mapping.
  auto cost_fn = flops_counter(std::forward<IdxToSz>(idxsz));
  // No target (externally constrained) indices: the result keeps whatever
  // free indices survive the full contraction.
  decltype(OptRes::indices) tidxs{};
  return single_term_opt_impl(network, tidxs, cost_fn, subnet_cse);
}

} // namespace detail
Expand All @@ -142,7 +272,8 @@ EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz) {
/// @note @c prod is assumed to consist of only Tensor expressions
///
template <has_index_extent IdxToSz>
ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) {
ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz,
bool subnet_cse = false) {
using ranges::views::filter;
using ranges::views::reverse;

Expand All @@ -152,7 +283,7 @@ ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) {
auto const tensors =
prod | filter(&ExprPtr::template is<Tensor>) | ranges::to_vector;
auto seq = detail::single_term_opt(TensorNetwork{tensors},
std::forward<IdxToSz>(idxsz));
std::forward<IdxToSz>(idxsz), subnet_cse);
auto result = container::svector<ExprPtr>{};
for (auto i : seq)
if (i == -1) {
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,53 @@ TEST_CASE("optimize", "[optimize]") {
}
}
}
SECTION("Single term optimization with CSE") {
  // Work on a scoped copy of the default context so the registry edits
  // below do not leak into other sections (verified at the end of the file).
  auto ctx_resetter =
      set_scoped_default_context(get_default_context().clone());
  auto reg = get_default_context().mutable_index_space_registry();
  mbpt::add_df_spaces(reg);
  mbpt::add_pao_spaces(reg);
  mbpt::add_ao_spaces(reg);
  // assign index-space extents:
  // i 10
  // a 40
  // μ̃ 50
  // Κ 90
  for (auto&& [k, v] :
       std::initializer_list<std::pair<std::wstring_view, size_t>>{
           {L"i", 10}, {L"a", 40}, {L"μ̃", 50}, {L"Κ", 90}}) {
    reg->retrieve_ptr(k)->approximate_size(v);
  }

  // helper: optimize \p prod using the registered extents as cost model;
  // \p cse toggles common-subnetwork elimination
  auto single_term_opt = [](Product const& prod, bool cse = true) {
    return opt::single_term_opt(
        prod,
        [](Index const& ix) {
          // null space contributes x1 to the size
          auto sz = ix.nonnull() ? ix.space().approximate_size() : 1;
          return sz;
        },
        /*subnet_cse=*/cse);
  };

  auto prod9 =
      deserialize("X{i1;a1} X{i2;a4} Y{a3;i3} Y{a1;i4}")->as<Product>();
  auto res9 = single_term_opt(prod9);
  // with CSE enabled we want the balanced tree (X Y) (X Y), whose two
  // subnetworks are structurally identical and thus costed once
  REQUIRE(extract(res9, {0, 0}) == prod9.at(0));
  REQUIRE(extract(res9, {0, 1}) == prod9.at(3));
  REQUIRE(extract(res9, {1, 0}) == prod9.at(1));
  REQUIRE(extract(res9, {1, 1}) == prod9.at(2));

  // with subnet_cse disabled the repeated (X Y) subnetwork earns no credit,
  // so a different (left-deep) tree is optimal
  auto res9_no_cse = single_term_opt(prod9, false);
  REQUIRE(extract(res9_no_cse, {0, 0, 0}) == prod9.at(0));
  REQUIRE(extract(res9_no_cse, {0, 0, 1}) == prod9.at(3));
  REQUIRE(extract(res9_no_cse, {0, 1}) == prod9.at(1));
  REQUIRE(extract(res9_no_cse, {1}) == prod9.at(2));
}

/// verify that space changes did not leak
auto reg = get_default_context().index_space_registry();
Expand Down