-
Notifications
You must be signed in to change notification settings - Fork 139
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Follow existing pattern for SplineC2C allowing for std::complex<T>
- Loading branch information
1 parent
4e8ec2e
commit 744c05f
Showing
4 changed files
with
1,286 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
// This file is distributed under the University of Illinois/NCSA Open Source License. | ||
// See LICENSE file in top directory for details. | ||
// | ||
// Copyright (c) 2019 QMCPACK developers. | ||
// | ||
// File developed by: Jeremy McMinnis, [email protected], University of Illinois at Urbana-Champaign | ||
// Jaron T. Krogel, [email protected], Oak Ridge National Laboratory | ||
// Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign | ||
// Mark A. Berrill, [email protected], Oak Ridge National Laboratory | ||
// Ye Luo, [email protected], Argonne National Laboratory | ||
// | ||
// File created by: Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
|
||
/** @file BsplineSetT.h | ||
* | ||
* BsplineSet is a SPOSet derived class and serves as a base class for B-spline SPO C2C/C2R/R2R implementation | ||
*/ | ||
#ifndef QMCPLUSPLUS_BSPLINESET_H | ||
#define QMCPLUSPLUS_BSPLINESET_H | ||
|
||
#include "QMCWaveFunctions/SPOSetT.h" | ||
#include "spline/einspline_engine.hpp" | ||
#include "spline/einspline_util.hpp" | ||
|
||
namespace qmcplusplus | ||
{ | ||
/** BsplineSet is the base class for SplineC2C, SplineC2R, SplineR2R. | ||
* Its derived template classes manage the storage and evaluation at given precision. | ||
* BsplineSet also implements a few fallback routines in case optimized implementation is not necessary in the derived class. | ||
*/ | ||
template<class T> | ||
class BsplineSetT : public SPOSetT<T> | ||
{ | ||
public: | ||
using PosType = typename SPOSetT<T>::PosType; | ||
using ValueVector = typename SPOSetT<T>::ValueVector; | ||
using GradVector = typename SPOSetT<T>::GradVector; | ||
using HessVector = typename SPOSetT<T>::HessVector; | ||
using GGGVector = typename SPOSetT<T>::GGGVector; | ||
using ValueMatrix = typename SPOSetT<T>::ValueMatrix; | ||
using GradMatrix = typename SPOSetT<T>::GradMatrix; | ||
using HessMatrix = typename SPOSetT<T>::HessMatrix; | ||
using GGGMatrix = typename SPOSetT<T>::GGGMatrix; | ||
|
||
using value_type = typename SPOSetT<T>::ValueMatrix::value_type; | ||
using grad_type = typename SPOSetT<T>::GradMatrix::value_type; | ||
|
||
// used in derived classes | ||
using RealType = typename SPOSetT<T>::RealType; | ||
using ValueType = typename SPOSetT<T>::ValueType; | ||
|
||
BsplineSetT(const std::string& my_name) : SPOSetT<T>(my_name), MyIndex(0), first_spo(0), last_spo(0) {} | ||
|
||
virtual bool isComplex() const = 0; | ||
virtual std::string getKeyword() const = 0; | ||
|
||
auto& getHalfG() const { return HalfG; } | ||
|
||
inline void init_base(int n) | ||
{ | ||
kPoints.resize(n); | ||
MakeTwoCopies.resize(n); | ||
BandIndexMap.resize(n); | ||
for (int i = 0; i < n; i++) | ||
BandIndexMap[i] = i; | ||
} | ||
|
||
///remap kpoints to group general kpoints & special kpoints | ||
int remap_kpoints() | ||
{ | ||
std::vector<PosType> k_copy(kPoints); | ||
const int nk = kPoints.size(); | ||
int nCB = 0; | ||
//two pass | ||
for (int i = 0; i < nk; ++i) | ||
{ | ||
if (MakeTwoCopies[i]) | ||
{ | ||
kPoints[nCB] = k_copy[i]; | ||
BandIndexMap[nCB++] = i; | ||
} | ||
} | ||
int nRealBands = nCB; | ||
for (int i = 0; i < nk; ++i) | ||
{ | ||
if (!MakeTwoCopies[i]) | ||
{ | ||
kPoints[nRealBands] = k_copy[i]; | ||
BandIndexMap[nRealBands++] = i; | ||
} | ||
} | ||
return nCB; //return the number of complex bands | ||
} | ||
|
||
std::unique_ptr<SPOSetT<T>> makeClone() const override = 0; | ||
|
||
void setOrbitalSetSize(int norbs) override { this->OrbitalSetSize = norbs; } | ||
|
||
void evaluate_notranspose(const ParticleSet& P, | ||
int first, | ||
int last, | ||
ValueMatrix& logdet, | ||
GradMatrix& dlogdet, | ||
ValueMatrix& d2logdet) override | ||
{ | ||
for (int iat = first, i = 0; iat < last; ++iat, ++i) | ||
{ | ||
ValueVector v(logdet[i], logdet.cols()); | ||
GradVector g(dlogdet[i], dlogdet.cols()); | ||
ValueVector l(d2logdet[i], d2logdet.cols()); | ||
this->evaluateVGL(P, iat, v, g, l); | ||
} | ||
} | ||
|
||
void mw_evaluate_notranspose(const RefVectorWithLeader<SPOSetT<T>>& spo_list, | ||
const RefVectorWithLeader<ParticleSet>& P_list, | ||
int first, | ||
int last, | ||
const RefVector<ValueMatrix>& logdet_list, | ||
const RefVector<GradMatrix>& dlogdet_list, | ||
const RefVector<ValueMatrix>& d2logdet_list) const override | ||
{ | ||
assert(this == &spo_list.getLeader()); | ||
const size_t nw = spo_list.size(); | ||
std::vector<ValueVector> mw_psi_v; | ||
std::vector<GradVector> mw_dpsi_v; | ||
std::vector<ValueVector> mw_d2psi_v; | ||
RefVector<ValueVector> psi_v_list; | ||
RefVector<GradVector> dpsi_v_list; | ||
RefVector<ValueVector> d2psi_v_list; | ||
mw_psi_v.reserve(nw); | ||
mw_dpsi_v.reserve(nw); | ||
mw_d2psi_v.reserve(nw); | ||
psi_v_list.reserve(nw); | ||
dpsi_v_list.reserve(nw); | ||
d2psi_v_list.reserve(nw); | ||
|
||
for (int iat = first, i = 0; iat < last; ++iat, ++i) | ||
{ | ||
mw_psi_v.clear(); | ||
mw_dpsi_v.clear(); | ||
mw_d2psi_v.clear(); | ||
psi_v_list.clear(); | ||
dpsi_v_list.clear(); | ||
d2psi_v_list.clear(); | ||
|
||
for (int iw = 0; iw < nw; iw++) | ||
{ | ||
mw_psi_v.emplace_back(logdet_list[iw].get()[i], logdet_list[iw].get().cols()); | ||
mw_dpsi_v.emplace_back(dlogdet_list[iw].get()[i], dlogdet_list[iw].get().cols()); | ||
mw_d2psi_v.emplace_back(d2logdet_list[iw].get()[i], d2logdet_list[iw].get().cols()); | ||
psi_v_list.push_back(mw_psi_v.back()); | ||
dpsi_v_list.push_back(mw_dpsi_v.back()); | ||
d2psi_v_list.push_back(mw_d2psi_v.back()); | ||
} | ||
|
||
this->mw_evaluateVGL(spo_list, P_list, iat, psi_v_list, dpsi_v_list, d2psi_v_list); | ||
} | ||
} | ||
|
||
void evaluate_notranspose(const ParticleSet& P, | ||
int first, | ||
int last, | ||
ValueMatrix& logdet, | ||
GradMatrix& dlogdet, | ||
HessMatrix& grad_grad_logdet) override | ||
{ | ||
for (int iat = first, i = 0; iat < last; ++iat, ++i) | ||
{ | ||
ValueVector v(logdet[i], logdet.cols()); | ||
GradVector g(dlogdet[i], dlogdet.cols()); | ||
HessVector h(grad_grad_logdet[i], grad_grad_logdet.cols()); | ||
this->evaluateVGH(P, iat, v, g, h); | ||
} | ||
} | ||
|
||
void evaluate_notranspose(const ParticleSet& P, | ||
int first, | ||
int last, | ||
ValueMatrix& logdet, | ||
GradMatrix& dlogdet, | ||
HessMatrix& grad_grad_logdet, | ||
GGGMatrix& grad_grad_grad_logdet) override | ||
{ | ||
for (int iat = first, i = 0; iat < last; ++iat, ++i) | ||
{ | ||
ValueVector v(logdet[i], logdet.cols()); | ||
GradVector g(dlogdet[i], dlogdet.cols()); | ||
HessVector h(grad_grad_logdet[i], grad_grad_logdet.cols()); | ||
GGGVector gh(grad_grad_grad_logdet[i], grad_grad_grad_logdet.cols()); | ||
this->evaluateVGHGH(P, iat, v, g, h, gh); | ||
} | ||
} | ||
|
||
void evaluateGradSource(const ParticleSet& P, | ||
int first, | ||
int last, | ||
const ParticleSet& source, | ||
int iat_src, | ||
GradMatrix& gradphi) override | ||
{ | ||
//Do nothing, since Einsplines don't explicitly depend on ion positions. | ||
} | ||
|
||
void evaluateGradSource(const ParticleSet& P, | ||
int first, | ||
int last, | ||
const ParticleSet& source, | ||
int iat_src, | ||
GradMatrix& grad_phi, | ||
HessMatrix& grad_grad_phi, | ||
GradMatrix& grad_lapl_phi) override | ||
{ | ||
//Do nothing, since Einsplines don't explicitly depend on ion positions. | ||
} | ||
|
||
template<class BSPLINESPO> | ||
friend struct SplineSetReader; | ||
friend struct BsplineReaderBase; | ||
|
||
|
||
protected: | ||
static const int D = QMCTraits::DIM; | ||
///Index of this adoptor, when multiple adoptors are used for NUMA or distributed cases | ||
size_t MyIndex; | ||
///first index of the SPOs this Spline handles | ||
size_t first_spo; | ||
///last index of the SPOs this Spline handles | ||
size_t last_spo; | ||
///sign bits at the G/2 boundaries | ||
TinyVector<int, D> HalfG; | ||
///flags to unpack sin/cos | ||
std::vector<bool> MakeTwoCopies; | ||
/** kpoints for each unique orbitals. | ||
* Note: for historic reason, this sign is opposite to what was used in DFT when orbitals were generated. | ||
* Changing the sign requires updating all the evaluation code. | ||
*/ | ||
std::vector<PosType> kPoints; | ||
///remap splines to orbitals | ||
aligned_vector<int> BandIndexMap; | ||
///band offsets used for communication | ||
std::vector<int> offset; | ||
}; | ||
|
||
} // namespace qmcplusplus | ||
#endif |
Oops, something went wrong.