diff --git a/include/xfac/tensor/tensor_ci_2.h b/include/xfac/tensor/tensor_ci_2.h index 5dfee6c..0f1ba31 100644 --- a/include/xfac/tensor/tensor_ci_2.h +++ b/include/xfac/tensor/tensor_ci_2.h @@ -79,8 +79,8 @@ struct TensorCI2 { {} /// constructs a TensorCI2 from a given tensor train, and using the provided function f. - TensorCI2(function)> f_, TensorTrain tt_, TensorCI2Param param_={}) - : f {f_, param_.useCachedFunction} + TensorCI2(TensorFunction const& f_, TensorTrain tt_, TensorCI2Param param_={}) + : f {f_} , param(param_) , pivotError(1) , Iset {tt_.M.size()} @@ -124,6 +124,11 @@ struct TensorCI2 { iterate(1,0); // just to define tt, while reevaluating the original f in the pivots. } + /// constructs a TensorCI2 from a given tensor train, and using the provided function f. + TensorCI2(function)> f_, TensorTrain tt_, TensorCI2Param param_={}) + : TensorCI2(TensorFunction {f_, param_.useCachedFunction}, tt_, param_) + {} + /// constructs a TensorCI2 from a given tensor train. It takes the tt as a true function f. TensorCI2(TensorTrain const& tt_, TensorCI2Param param_={}) : TensorCI2(tt_, tt_, param_) {} @@ -399,6 +404,12 @@ class QTensorCI: public TensorCI2 { , grid {grid_} {} + /// constructs a QTensorCI from a function f:(u1,u2,...,un)->T, a tensor train and the given quantics grid + QTensorCI(function)> f_, grid::Quantics grid_, TensorTrain tt_, TensorCI2Param par={}) + : TensorCI2 {tensorFun(f_,grid_), tt_, par} + , grid {grid_} + {} + /// returns the underline quantics tensor train QTensorTrain get_qtt() const { return {this->tt, grid}; } diff --git a/test/test_tensor_ci_2.cpp b/test/test_tensor_ci_2.cpp index 03e1b80..f0f6ed0 100644 --- a/test/test_tensor_ci_2.cpp +++ b/test/test_tensor_ci_2.cpp @@ -143,6 +143,40 @@ TEST_CASE("quantics 2") cout << "true error=" << ci.trueError(1< cache; + auto f=[&](vector y) { double x = y[0]; return cache[x]=f2(x); }; + + auto ci=QTensorCI(f, grid::Quantics {0.0, 1.0, 30}, {.bondDim=20, .useCachedFunction=true}); + ci.iterate(3); + + auto ci_2=QTensorCI(f, grid::Quantics {0.0, 1.0, 30}, ci.get_qtt().tt, {.bondDim=20, .useCachedFunction=true}); + //ci_2.iterate(); + auto res = ci_2.get_qtt().integral(); + REQUIRE(std::abs(res - exact)<1e-9); + } + + SECTION( "restart from tensor train complex" ) + { + function f2=[&](double x) { return pow(x+0.1,5)+pow(x-0.5,4); }; + auto exact = pow(1.1,6)/6-pow(0.1,6)/6 + pow(0.5,5)/5+pow(0.5,5)/5; + + auto f=[&](vector y) { return cmpx(f2(y[0]),f2(y[1])); }; + + auto ci=QTensorCI(f, grid::Quantics {0.0, 1.0, 40, 2, true}); + ci.iterate(5); + ci.get_qtt().save("my_qtt.dat"); + + auto qtt=QTensorTrain::load("my_qtt.dat"); + auto ci_2=QTensorCI(f, qtt.grid, qtt.tt); + auto res = ci_2.get_qtt().integral(); + REQUIRE(std::abs(res - cmpx(exact,exact))<1e-9); + } + SECTION( "copy pivots quantics" ) { function f2=[&](double x) { return pow(x+0.1,5)+pow(x-0.5,4); };