Skip to content

Commit

Permalink
Add missing function
Browse files Browse the repository at this point in the history
Turn friend class into a template class
  • Loading branch information
williamfgc committed Jul 25, 2023
1 parent 9775e31 commit 6d3edef
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
28 changes: 28 additions & 0 deletions src/QMCWaveFunctions/SPOSetT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,18 @@ void SPOSetT<T>::mw_evaluate_notranspose(const RefVectorWithLeader<SPOSetT<T>>&
spo_list[iw].evaluate_notranspose(P_list[iw], first, last, logdet_list[iw], dlogdet_list[iw], d2logdet_list[iw]);
}

template<class T>
void SPOSetT<T>::evaluate_notranspose(const ParticleSet& P,
int first,
int last,
ValueMatrix& logdet,
GradMatrix& dlogdet,
HessMatrix& grad_grad_logdet,
GGGMatrix& grad_grad_grad_logdet)
{
throw std::runtime_error("Need specialization of SPOSet::evaluate_notranspose() for grad_grad_grad_logdet. \n");
}

template<class T>
std::unique_ptr<SPOSetT<T>> SPOSetT<T>::makeClone() const
{
Expand Down Expand Up @@ -370,6 +382,22 @@ void SPOSetT<T>::evaluateGradSource(const ParticleSet& P,
"must be overloaded when the SPOSet has ion derivatives.");
}

template<class T>
void SPOSetT<T>::evaluateGradSource(const ParticleSet& P,
int first,
int last,
const ParticleSet& source,
int iat_src,
GradMatrix& grad_phi,
HessMatrix& grad_grad_phi,
GradMatrix& grad_lapl_phi)
{
if (hasIonDerivs())
throw std::logic_error("Bug!! " + getClassName() +
"::evaluateGradSource "
"must be overloaded when the SPOSet has ion derivatives.");
}

template<class T>
void SPOSetT<T>::evaluateGradSourceRow(const ParticleSet& P,
int iel,
Expand Down
14 changes: 5 additions & 9 deletions src/QMCWaveFunctions/SPOSetT.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@ template<class T>
class SPOSetT;
namespace testing
{
opt_variables_type& getMyVars(SPOSetT<float>& spo);
opt_variables_type& getMyVars(SPOSetT<double>& spo);
opt_variables_type& getMyVars(SPOSetT<std::complex<float>>& spo);
opt_variables_type& getMyVars(SPOSetT<std::complex<double>>& spo);
template<class T>
opt_variables_type& getMyVars(SPOSetT<T>& spo);
} // namespace testing


Expand Down Expand Up @@ -481,7 +479,7 @@ class SPOSetT : public QMCTraits
int last,
const ParticleSet& source,
int iat_src,
GradMatrix& gradphi);
GradMatrix& gradphi){};

/** evaluate the gradients of values, gradients, laplacians of this single-particle orbital
* for [first,last) target particles with respect to the given source particle
Expand Down Expand Up @@ -565,10 +563,8 @@ class SPOSetT : public QMCTraits
/// Optimizable variables
opt_variables_type myVars;

friend opt_variables_type& testing::getMyVars(SPOSetT<float>& spo);
friend opt_variables_type& testing::getMyVars(SPOSetT<double>& spo);
friend opt_variables_type& testing::getMyVars(SPOSetT<std::complex<float>>& spo);
friend opt_variables_type& testing::getMyVars(SPOSetT<std::complex<double>>& spo);
template<class T>
friend opt_variables_type& testing::getMyVars(SPOSetT<T>& spo);
};

template<class T>
Expand Down
11 changes: 6 additions & 5 deletions src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,11 +644,12 @@ TEST_CASE("RotatedSPOs construct delta matrix", "[wavefunction]")

namespace testing
{
opt_variables_type& getMyVars(SPOSet& rot) { return rot.myVars; }
opt_variables_type& getMyVars(SPOSetT<float>& rot) { return rot.myVars; }
opt_variables_type& getMyVars(SPOSetT<double>& rot) { return rot.myVars; }
opt_variables_type& getMyVars(SPOSetT<std::complex<float>>& rot) { return rot.myVars; }
opt_variables_type& getMyVars(SPOSetT<std::complex<double>>& rot) { return rot.myVars; }

template<class T>
opt_variables_type& getMyVars(SPOSet<T>& rot)
{
return rot.myVars;
}
opt_variables_type& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull; }
std::vector<std::vector<QMCTraits::RealType>>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; }
} // namespace testing
Expand Down

0 comments on commit 6d3edef

Please sign in to comment.