Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions include/xfac/tensor/tensor_ci_2.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ struct TensorCI2 {
{}

/// constructs a TensorCI2 from a given tensor train, and using the provided function f.
TensorCI2(function<T(vector<int>)> f_, TensorTrain<T> tt_, TensorCI2Param param_={})
: f {f_, param_.useCachedFunction}
TensorCI2(TensorFunction<T> const& f_, TensorTrain<T> tt_, TensorCI2Param param_={})
: f {f_}
, param(param_)
, pivotError(1)
, Iset {tt_.M.size()}
Expand Down Expand Up @@ -124,6 +124,11 @@ struct TensorCI2 {
iterate(1,0); // just to define tt, while reevaluating the original f in the pivots.
}

/// constructs a TensorCI2 from a given tensor train, and using the provided function f.
TensorCI2(function<T(vector<int>)> f_, TensorTrain<T> tt_, TensorCI2Param param_={})
: TensorCI2(TensorFunction<T> {f_, param_.useCachedFunction}, tt_, param_)
{}

/// constructs a TensorCI2 from a given tensor train. It takes the tt as a true function f.
TensorCI2(TensorTrain<T> const& tt_, TensorCI2Param param_={}) : TensorCI2(tt_, tt_, param_) {}

Expand Down Expand Up @@ -399,6 +404,12 @@ class QTensorCI: public TensorCI2<T> {
, grid {grid_}
{}

/// constructs a QTensorCI from a function f:(u1,u2,...,un)->T, a tensor train and the given quantics grid
QTensorCI(function<T(vector<double>)> f_, grid::Quantics grid_, TensorTrain<T> tt_, TensorCI2Param par={})
: TensorCI2<T> {tensorFun(f_,grid_), tt_, par}
, grid {grid_}
{}

/// returns the underline quantics tensor train
QTensorTrain<T> get_qtt() const { return {this->tt, grid}; }

Expand Down
34 changes: 34 additions & 0 deletions test/test_tensor_ci_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,40 @@ TEST_CASE("quantics 2")
cout << "true error=" << ci.trueError(1<<ci.grid.nBit) << "\n";
}

SECTION( "restart from tensor train" )
{
function f2=[&](double x) { return pow(x+0.1,5)+pow(x-0.5,4); };
auto exact = pow(1.1,6)/6-pow(0.1,6)/6 + pow(0.5,5)/5+pow(0.5,5)/5;

map<double,double> cache;
auto f=[&](vector<double> y) { double x = y[0]; return cache[x]=f2(x); };

auto ci=QTensorCI<double>(f, grid::Quantics {0.0, 1.0, 30}, {.bondDim=20, .useCachedFunction=true});
ci.iterate(3);

auto ci_2=QTensorCI<double>(f, grid::Quantics {0.0, 1.0, 30}, ci.get_qtt().tt, {.bondDim=20, .useCachedFunction=true});
//ci_2.iterate();
auto res = ci_2.get_qtt().integral();
REQUIRE(std::abs(res - exact)<1e-9);
}

SECTION( "restart from tensor train complex" )
{
function f2=[&](double x) { return pow(x+0.1,5)+pow(x-0.5,4); };
auto exact = pow(1.1,6)/6-pow(0.1,6)/6 + pow(0.5,5)/5+pow(0.5,5)/5;

auto f=[&](vector<double> y) { return cmpx(f2(y[0]),f2(y[1])); };

auto ci=QTensorCI<cmpx>(f, grid::Quantics {0.0, 1.0, 40, 2, true});
ci.iterate(5);
ci.get_qtt().save("my_qtt.dat");

auto qtt=QTensorTrain<cmpx>::load("my_qtt.dat");
auto ci_2=QTensorCI<cmpx>(f, qtt.grid, qtt.tt);
auto res = ci_2.get_qtt().integral();
REQUIRE(std::abs(res - cmpx(exact,exact))<1e-9);
}

SECTION( "copy pivots quantics" )
{
function f2=[&](double x) { return pow(x+0.1,5)+pow(x-0.5,4); };
Expand Down