Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions include/xfac/tensor/tensor_ci_2.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ struct TensorCI2 {
{}

/// constructs a TensorCI2 from a given tensor train, and using the provided function f.
TensorCI2(function<T(vector<int>)> f_, TensorTrain<T> tt_, TensorCI2Param param_={})
: f {f_, param_.useCachedFunction}
TensorCI2(TensorFunction<T> const& f_, TensorTrain<T> tt_, TensorCI2Param param_={})
: f {f_}
, param(param_)
, pivotError(1)
, Iset {tt_.M.size()}
Expand Down Expand Up @@ -124,6 +124,11 @@ struct TensorCI2 {
iterate(1,0); // just to define tt, while reevaluating the original f in the pivots.
}

/// constructs a TensorCI2 from a given tensor train, and using the provided function f.
TensorCI2(function<T(vector<int>)> f_, TensorTrain<T> tt_, TensorCI2Param param_={})
: TensorCI2(TensorFunction<T> {f_, param_.useCachedFunction}, tt_, param_)
{}

/// constructs a TensorCI2 from a given tensor train. It takes the tt as a true function f.
TensorCI2(TensorTrain<T> const& tt_, TensorCI2Param param_={}) : TensorCI2(tt_, tt_, param_) {}

Expand Down Expand Up @@ -387,6 +392,8 @@ class QTensorCI: public TensorCI2<T> {
public:
grid::Quantics grid;

QTensorCI() = default;

/// constructs a rank-1 QTensorCI from a function f:(u1,u2,...,un)->T and the given quantics grid
QTensorCI(function<T(vector<double>)> f_, grid::Quantics grid_, TensorCI2Param par={})
: TensorCI2<T> {tensorFun(f_,grid_), grid_.tensorDims(), par}
Expand All @@ -399,6 +406,12 @@ class QTensorCI: public TensorCI2<T> {
, grid {grid_}
{}

/// constructs a rank-1 QTensorCI from a function f:(u1,u2,...,un)->T, a tensor train and the given quantics grid
QTensorCI(function<T(vector<double>)> f_, grid::Quantics grid_, TensorTrain<T> tt_, TensorCI2Param par={})
: TensorCI2<T> {tensorFun(f_,grid_), tt_, par}
, grid {grid_}
{}

/// returns the underline quantics tensor train
QTensorTrain<T> get_qtt() const { return {this->tt, grid}; }

Expand Down