Skip to content

Commit

Permalink
Specialize functions in RotatedSPOsT
Browse files Browse the repository at this point in the history
Fix function signature
  • Loading branch information
williamfgc committed Aug 28, 2023
1 parent acb8862 commit 9c61923
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 23 deletions.
6 changes: 3 additions & 3 deletions src/QMCWaveFunctions/RotatedSPOsT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -975,9 +975,9 @@ void RotatedSPOsT<T>::evaluateDerivatives(ParticleSet& P,
template<typename T>
void RotatedSPOsT<T>::evaluateDerivativesWF(ParticleSet& P,

Check warning on line 976 in src/QMCWaveFunctions/RotatedSPOsT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/RotatedSPOsT.cpp#L976

Added line #L976 was not covered by tests
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const FullRealType& psiCurrent,
const std::vector<T>& Coeff,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<ValueType>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
const ValueVector& detValues_up,
Expand Down
7 changes: 4 additions & 3 deletions src/QMCWaveFunctions/RotatedSPOsT.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObject
public:
using IndexType = typename SPOSetT<T>::IndexType;
using RealType = typename SPOSetT<T>::RealType;
using ValueType = typename SPOSetT<T>::ValueType;
using FullRealType = typename SPOSetT<T>::FullRealType;
using ValueVector = typename SPOSetT<T>::ValueVector;
using ValueMatrix = typename SPOSetT<T>::ValueMatrix;
Expand Down Expand Up @@ -200,9 +201,9 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObject

void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const FullRealType& psiCurrent,
const std::vector<T>& Coeff,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<ValueType>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
const ValueVector& detValues_up,
Expand Down
86 changes: 73 additions & 13 deletions src/QMCWaveFunctions/SPOSetBuilderT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
#include "SPOSetBuilderT.h"
#include "OhmmsData/AttributeSet.h"
#include <Message/UniformCommunicateError.h>

#ifndef QMC_COMPLEX
#include "QMCWaveFunctions/RotatedSPOsT.h"
#endif
#include "QMCWaveFunctions/RotatedSPOsT.h" // only for real wavefunctions

namespace qmcplusplus
{
Expand Down Expand Up @@ -133,8 +130,8 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createSPOSet(xmlNodePtr cur)
return sposet;

Check warning on line 130 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L130

Added line #L130 was not covered by tests
}

template<typename T>
std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cur)
template<>
std::unique_ptr<SPOSetT<float>> SPOSetBuilderT<float>::createRotatedSPOSet(xmlNodePtr cur)

Check warning on line 134 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L134

Added line #L134 was not covered by tests
{
std::string spo_object_name;
std::string method;
Expand All @@ -143,12 +140,49 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);

std::unique_ptr<SPOSetT<float>> sposet;

Check warning on line 143 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L143

Added line #L143 was not covered by tests
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "sposet")
{
sposet = createSPOSet(element);
}
});

Check warning on line 149 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L149

Added line #L149 was not covered by tests

if (!sposet)
myComm->barrier_and_abort("Rotated SPO needs an SPOset");

if (!sposet->isRotationSupported())
myComm->barrier_and_abort("Orbital rotation not supported with '" + sposet->getName() + "' of type '" +
sposet->getClassName() + "'.");

sposet->storeParamsBeforeRotation();
auto rot_spo = std::make_unique<RotatedSPOsT<float>>(spo_object_name, std::move(sposet));

if (method == "history")
rot_spo->set_use_global_rotation(false);

Check warning on line 162 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L162

Added line #L162 was not covered by tests

#ifdef QMC_COMPLEX
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;
#else
std::unique_ptr<SPOSetT<T>> sposet;
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "opt_vars")
{
std::vector<RealType> params;
putContent(params, element);
rot_spo->setRotationParameters(params);
}
});

Check warning on line 171 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L171

Added line #L171 was not covered by tests
return rot_spo;
}

template<>
std::unique_ptr<SPOSetT<double>> SPOSetBuilderT<double>::createRotatedSPOSet(xmlNodePtr cur)

Check warning on line 176 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L176

Added line #L176 was not covered by tests
{
std::string spo_object_name;
std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "name");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);

std::unique_ptr<SPOSetT<double>> sposet;

Check warning on line 185 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L185

Added line #L185 was not covered by tests
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "sposet")
{
Expand All @@ -164,7 +198,7 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
sposet->getClassName() + "'.");

sposet->storeParamsBeforeRotation();
auto rot_spo = std::make_unique<RotatedSPOsT<T>>(spo_object_name, std::move(sposet));
auto rot_spo = std::make_unique<RotatedSPOsT<double>>(spo_object_name, std::move(sposet));

if (method == "history")
rot_spo->set_use_global_rotation(false);

Check warning on line 204 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L204

Added line #L204 was not covered by tests
Expand All @@ -178,8 +212,34 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
}
});

Check warning on line 213 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L213

Added line #L213 was not covered by tests
return rot_spo;
#endif
}

template<>
std::unique_ptr<SPOSetT<std::complex<float>>> SPOSetBuilderT<std::complex<float>>::createRotatedSPOSet(xmlNodePtr cur)

Check warning on line 218 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L218

Added line #L218 was not covered by tests
{
std::string spo_object_name;
std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "name");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;

Check warning on line 227 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L227

Added line #L227 was not covered by tests
}

template<>
std::unique_ptr<SPOSetT<std::complex<double>>> SPOSetBuilderT<std::complex<double>>::createRotatedSPOSet(xmlNodePtr cur)

Check warning on line 231 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L231

Added line #L231 was not covered by tests
{
std::string spo_object_name;
std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "name");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;

Check warning on line 240 in src/QMCWaveFunctions/SPOSetBuilderT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetBuilderT.cpp#L240

Added line #L240 was not covered by tests
}

template class SPOSetBuilderT<double>;
template class SPOSetBuilderT<float>;
template class SPOSetBuilderT<std::complex<double>>;
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/SPOSetT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ void SPOSetT<T>::evaluateDerivatives(ParticleSet& P,
template<class T>
void SPOSetT<T>::evaluateDerivativesWF(ParticleSet& P,

Check warning on line 360 in src/QMCWaveFunctions/SPOSetT.cpp

View check run for this annotation

Codecov / codecov/patch

src/QMCWaveFunctions/SPOSetT.cpp#L360

Added line #L360 was not covered by tests
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const typename QTFull::ValueType& psiCurrent,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<T>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/SPOSetT.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ class SPOSetT : public QMCTraits
*/
virtual void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const typename QTFull::ValueType& psiCurrent,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<T>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
Expand Down

0 comments on commit 9c61923

Please sign in to comment.