Skip to content

Commit 0cc0944

Browse files
committed
cleanup
1 parent ae413af commit 0cc0944

2 files changed

Lines changed: 31 additions & 45 deletions

File tree

SeQuant/core/optimize/single_term.hpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,30 +95,35 @@ struct SubNetEqual {
9595
///
9696
/// \tparam CostFn A function object type that computes the cost of a single
9797
/// binary contraction.
98-
/// Expected signature: \code double(meta::range_of<Index> auto
99-
/// const& lhs, meta::range_of<Index> auto const& rhs,
100-
/// meta::range_of<Index> auto const& res) \endcode
98+
/// Expected signature:
99+
// \code double(meta::range_of<Index> auto const& lhs,
100+
// meta::range_of<Index> auto const& rhs,
101+
/// meta::range_of<Index> auto const& res)
102+
// \endcode
101103
///
102104
/// \param network The \ref TensorNetwork containing the tensors to be
103-
/// contracted. \param tidxs The set of indices that should remain open in the
104-
/// final result. \param cost_fn The cost model used to evaluate contractions
105-
/// (e.g., flop count). \param subnet_cse If true, enables Common Subexpression
105+
/// contracted.
106+
// \param tidxs The set of indices that should remain open in the
107+
/// final result.
108+
// \param cost_fn The cost model used to evaluate contractions
109+
/// (e.g., flop count).
110+
// \param subnet_cse If true, enables Common Subexpression
106111
/// Elimination (CSE) for
107-
/// equivalent subnetworks. When enabled, the cost of
108-
/// evaluating structurally identical subnetworks is counted
109-
/// only once in the total cost of a contraction tree.
110-
/// Equivalence is determined by canonicalizing the subnetwork
111-
/// graph.
112+
/// equivalent subnetworks. When enabled, the cost of
113+
/// evaluating structurally identical subnetworks is counted
114+
/// only once in the total cost of a contraction tree.
115+
/// Equivalence is determined by canonicalizing the subnetwork
116+
/// graph.
112117
///
113118
/// \return An \ref EvalSequence representing the optimal contraction order.
114119
///
115120
/// \details The optimization uses a bitmask-based dynamic programming approach
116-
/// where each state represents a subnetwork (subset of tensors).
117-
/// If \p subnet_cse is enabled, the algorithm precomputes canonical
118-
/// metadata for every possible subnetwork to identify common
119-
/// structures. This allows it to find trees that benefit from reusing
120-
/// intermediate results, which is particularly effective for
121-
/// expressions with repeating tensor patterns.
121+
/// where each state represents a subnetwork (subset of tensors).
122+
/// If \p subnet_cse is enabled, the algorithm precomputes canonical
123+
/// metadata for every possible subnetwork to identify common
124+
/// structures. This allows it to find trees that benefit from reusing
125+
/// intermediate results, which is particularly effective for
126+
/// expressions with repeating tensor patterns.
122127
///
123128
template <typename CostFn>
124129
requires requires(CostFn&& fn, decltype(OptRes::indices) const& ixs) {

tests/unit/test_optimize.cpp

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -353,16 +353,21 @@ TEST_CASE("optimize", "[optimize]") {
353353
auto prod9 =
354354
deserialize("X{i1;a1} X{i2;a4} Y{a3;i3} Y{a1;i4}")->as<Product>();
355355
auto res9 = single_term_opt(prod9);
356-
// take a look at res9_ for a result with subnet_cse disabled
357-
// should give different result
358-
// auto res9_ = single_term_opt(prod9, false);
359-
// std::wcout << "res9_\n" << serialize(res9_) << std::endl;
360356
// this is the one we want to find
361357
// (X Y) (X Y)
362358
REQUIRE(extract(res9, {0, 0}) == prod9.at(0));
363359
REQUIRE(extract(res9, {0, 1}) == prod9.at(3));
364360
REQUIRE(extract(res9, {1, 0}) == prod9.at(1));
365361
REQUIRE(extract(res9, {1, 1}) == prod9.at(2));
362+
363+
// take a look at res9_ for a result with subnet_cse disabled
364+
// should give different result
365+
// std::wcout << "res9_\n" << serialize(res9_) << std::endl;
366+
auto res9_no_cse = single_term_opt(prod9, false);
367+
REQUIRE(extract(res9_no_cse, {0, 0, 0}) == prod9.at(0));
368+
REQUIRE(extract(res9_no_cse, {0, 0, 1}) == prod9.at(3));
369+
REQUIRE(extract(res9_no_cse, {0, 1}) == prod9.at(1));
370+
REQUIRE(extract(res9_no_cse, {1}) == prod9.at(2));
366371
}
367372

368373
/// verify that space changes did not leak
@@ -371,27 +376,3 @@ TEST_CASE("optimize", "[optimize]") {
371376
REQUIRE(uocc);
372377
REQUIRE(uocc->approximate_size() == 10);
373378
}
374-
375-
TEST_CASE("feature optimize", "[feature]") {
376-
using namespace sequant;
377-
auto ctx_resetter = set_scoped_default_context(get_default_context().clone());
378-
auto reg = get_default_context().mutable_index_space_registry();
379-
mbpt::add_df_spaces(reg);
380-
mbpt::add_pao_spaces(reg);
381-
mbpt::add_ao_spaces(reg);
382-
// i 10
383-
// a 40
384-
// μ̃ 50
385-
// Κ 90
386-
for (auto&& [k, v] :
387-
std::initializer_list<std::pair<std::wstring_view, size_t>>{
388-
{L"i", 10}, {L"a", 40}, {L"μ̃", 50}, {L"Κ", 90}}) {
389-
reg->retrieve_ptr(k)->approximate_size(v);
390-
}
391-
392-
for (auto&& ix :
393-
std::initializer_list<std::wstring_view>{L"i", L"a", L"μ̃", L"Κ"}) {
394-
std::wcout << std::format(L"{}: {}\n", ix,
395-
reg->retrieve_ptr(ix)->approximate_size());
396-
}
397-
}

0 commit comments

Comments
 (0)