Skip to content

Commit 2ef516b

Browse files
authored
Poisson: 2D FFT in x-y plus tridiagonal solve in z-direction (#139)
1 parent 15838d8 commit 2ef516b

File tree

4 files changed

+361
-0
lines changed

4 files changed

+361
-0
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# GNU Make configuration for this example: builds the sources listed in the
# local Make.package against AMReX and heFFTe.
# ?= only assigns when the variable is not already set, so AMREX_HOME and
# HEFFTE_HOME can be overridden from the environment or the make command line.
AMREX_HOME ?= ../../../../amrex
2+
HEFFTE_HOME ?= ../../../../heffte/build
3+
4+
# Build options; presumably consumed by the AMReX make files included below.
DEBUG = FALSE
5+
# main.cpp static_asserts AMREX_SPACEDIM == 3, so this example is 3D only.
DIM = 3
6+
COMP = gcc
7+
TINY_PROFILE = FALSE
8+
USE_MPI = TRUE
9+
# USE_CUDA / USE_HIP also select the FFT backend library linked further down.
USE_CUDA = FALSE
10+
USE_HIP = FALSE
11+
12+
BL_NO_FORT = TRUE
13+
14+
# AMReX definitions first, then the AMReX base package, then this example's sources.
include $(AMREX_HOME)/Tools/GNUMake/Make.defs
15+
include $(AMREX_HOME)/Src/Base/Make.package
16+
include Make.package
17+
18+
# Add the heFFTe headers and library directory to the search paths.
VPATH_LOCATIONS += $(HEFFTE_HOME)/include
19+
INCLUDE_LOCATIONS += $(HEFFTE_HOME)/include
20+
LIBRARY_LOCATIONS += $(HEFFTE_HOME)/lib
21+
22+
libraries += -lheffte
23+
24+
# Link the FFT library matching the backend: cuFFT for CUDA, rocFFT for HIP,
# FFTW (MPI, single- and double-precision libraries) otherwise.
ifeq ($(USE_CUDA),TRUE)
25+
libraries += -lcufft
26+
else ifeq ($(USE_HIP),TRUE)
27+
# Use rocFFT. ROC_PATH is defined in amrex
28+
INCLUDE_LOCATIONS += $(ROC_PATH)/rocfft/include
29+
LIBRARY_LOCATIONS += $(ROC_PATH)/rocfft/lib
30+
# NOTE(review): this branch appends to upper-case LIBRARIES while the other
# branches append to lower-case libraries - confirm both reach the link line.
LIBRARIES += -L$(ROC_PATH)/rocfft/lib -lrocfft
31+
else
32+
libraries += -lfftw3_mpi -lfftw3f -lfftw3
33+
endif
34+
35+
# Build rules must come last, after all variables above are final.
include $(AMREX_HOME)/Tools/GNUMake/Make.rules
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CEXE_sources += main.cpp
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
n_cell_x = 64
2+
n_cell_y = 64
3+
n_cell_z = 64
4+
5+
prob_lo_x = 0.
6+
prob_lo_y = 0.
7+
prob_lo_z = 0.
8+
9+
prob_hi_x = 1.
10+
prob_hi_y = 1.
11+
prob_hi_z = 1.
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
#include <heffte.h>
2+
3+
#include <AMReX.H>
4+
#include <AMReX_MultiFab.H>
5+
#include <AMReX_ParmParse.H>
6+
#include <AMReX_GpuComplex.H>
7+
#include <AMReX_PlotFileUtil.H>
8+
9+
using namespace amrex;
10+
11+
static_assert(AMREX_SPACEDIM == 3);
12+
13+
// Poisson solve on a 3D box: batched 2D real-to-complex FFTs in x-y (heFFTe),
// a tridiagonal solve in z for every (x,y) Fourier mode, then the inverse FFTs.
// Reads n_cell_* (required) and prob_lo_*/prob_hi_* (optional) from the inputs
// file, writes rhs and solution to a plotfile named "plt", and prints the
// min/max of the rhs and of a finite-difference residual.
int main (int argc, char* argv[])
14+
{
15+
amrex::Initialize(argc, argv); {
16+
17+
BL_PROFILE("main");
18+
19+
// **********************************
20+
// DECLARE SIMULATION PARAMETERS
21+
// **********************************
22+
23+
// number of cells on each side of the domain
24+
int n_cell_x;
25+
int n_cell_y;
26+
int n_cell_z;
27+
28+
// physical dimensions of the domain
29+
Real prob_lo_x = 0.;
30+
Real prob_lo_y = 0.;
31+
Real prob_lo_z = 0.;
32+
Real prob_hi_x = 1.;
33+
Real prob_hi_y = 1.;
34+
Real prob_hi_z = 1.;
35+
36+
// **********************************
37+
// READ PARAMETER VALUES FROM INPUTS FILE
38+
// **********************************
39+
{
40+
// ParmParse is a way of reading inputs from the inputs file
41+
// pp.get means we require the inputs file to have it
42+
// pp.query means the inputs file may optionally provide it; otherwise the default value set above is kept
43+
44+
ParmParse pp;
45+
46+
pp.get("n_cell_x",n_cell_x);
47+
pp.get("n_cell_y",n_cell_y);
48+
pp.get("n_cell_z",n_cell_z);
49+
50+
pp.query("prob_lo_x",prob_lo_x);
51+
pp.query("prob_lo_y",prob_lo_y);
52+
pp.query("prob_lo_z",prob_lo_z);
53+
54+
pp.query("prob_hi_x",prob_hi_x);
55+
pp.query("prob_hi_y",prob_hi_y);
56+
pp.query("prob_hi_z",prob_hi_z);
57+
}
58+
59+
// Determine the domain length in each direction
// (only x and y are needed: they set the Fourier wavenumbers; z is handled by the tridiagonal solve)
60+
Real L_x = std::abs(prob_hi_x - prob_lo_x);
61+
Real L_y = std::abs(prob_hi_y - prob_lo_y);
62+
63+
// define lower and upper indices of domain
64+
IntVect dom_lo(AMREX_D_DECL( 0, 0, 0));
65+
IntVect dom_hi(AMREX_D_DECL(n_cell_x-1, n_cell_y-1, n_cell_z-1));
66+
67+
// Make a single box that is the entire domain
68+
Box domain(dom_lo, dom_hi);
69+
70+
// Initialize the boxarray "ba" from the single box "domain" There are
71+
// exactly nprocs boxes. The domain decomposition is done in the x- and
72+
// y-directions, but not the z-direction.
73+
BoxArray ba = amrex::decompose(domain, ParallelDescriptor::NProcs(),
74+
{AMREX_D_DECL(true,true,false)});
75+
76+
// How Boxes are distributed among MPI processes
77+
DistributionMapping dm(ba);
78+
79+
// This defines the physical box size in each direction
80+
RealBox real_box({ AMREX_D_DECL(prob_lo_x, prob_lo_y, prob_lo_z)},
81+
{ AMREX_D_DECL(prob_hi_x, prob_hi_y, prob_hi_z)} );
82+
83+
// periodic in all directions as far as the geometry is concerned
// (the z-direction solve below applies homogeneous Neumann conditions instead)
84+
Array<int,AMREX_SPACEDIM> is_periodic{AMREX_D_DECL(1,1,1)};
85+
86+
// geometry object for real data
87+
Geometry geom(domain, real_box, CoordSys::cartesian, is_periodic);
88+
89+
// extract dx from the geometry object
90+
GpuArray<Real,AMREX_SPACEDIM> dx = geom.CellSizeArray();
91+
92+
// right-hand side and solution: one component, no ghost cells
MultiFab rhs(ba,dm,1,0);
93+
MultiFab soln(ba,dm,1,0);
94+
95+
// check to make sure each MPI rank has exactly 1 box
96+
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(rhs.local_size() == 1, "Must have one Box per MPI process");
97+
98+
for (MFIter mfi(rhs); mfi.isValid(); ++mfi) {
99+
100+
Array4<Real> const& rhs_ptr = rhs.array(mfi);
101+
102+
const Box& bx = mfi.fabbox();
103+
104+
amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
105+
{
106+
107+
// **********************************
108+
// SET VALUES FOR EACH CELL
109+
// **********************************
110+
111+
// NOTE(review): cell centers are measured from 0 rather than prob_lo_*, and the
// Gaussian below is centered at (0.5,0.5,0.5); this matches the default [0,1]^3 domain only.
Real x = (i+0.5) * dx[0];
112+
Real y = (AMREX_SPACEDIM>=2) ? (j+0.5) * dx[1] : 0.;
113+
Real z = (AMREX_SPACEDIM==3) ? (k+0.5) * dx[2] : 0.;
114+
115+
rhs_ptr(i,j,k) = std::exp(-10.*((x-0.5)*(x-0.5)+(y-0.5)*(y-0.5)+(z-0.5)*(z-0.5)));
116+
117+
});
118+
}
119+
120+
// Shift rhs so that its sum is zero.
121+
auto rhosum = rhs.sum(0);
122+
rhs.plus(-rhosum/geom.Domain().d_numPts(), 0, 1);
123+
124+
// since there is 1 MPI rank per box, here each MPI rank obtains its local box and the associated boxid
125+
Box local_box;
126+
// NOTE(review): left uninitialized; it is only assigned when this rank owns a box,
// which the one-box-per-rank assertion above is relied on to guarantee.
int local_boxid;
127+
{
128+
for (int i = 0; i < ba.size(); ++i) {
129+
Box b = ba[i];
130+
// each MPI rank has its own local_box Box and local_boxid ID
131+
if (ParallelDescriptor::MyProc() == dm[i]) {
132+
local_box = b;
133+
local_boxid = i;
134+
}
135+
}
136+
}
137+
138+
// now each MPI rank works on its own box
139+
// for real->complex fft's, the fft is stored in an (nx/2+1) x ny x nz dataset
140+
141+
// start by coarsening each box by 2 in the x-direction
142+
Box c_local_box = amrex::coarsen(local_box, IntVect(AMREX_D_DECL(2,1,1)));
143+
144+
// if the coarsened box's high-x index is even, we shrink the size by 1 in x
145+
// this avoids overlap between coarsened boxes
146+
if (c_local_box.bigEnd(0) * 2 == local_box.bigEnd(0)) {
147+
c_local_box.setBig(0,c_local_box.bigEnd(0)-1);
148+
}
149+
// for any boxes that touch the hi-x domain we
150+
// increase the size of boxes by 1 in x
151+
// this makes the overall fft dataset have size (Nx/2+1 x Ny x Nz)
152+
if (local_box.bigEnd(0) == geom.Domain().bigEnd(0)) {
153+
c_local_box.growHi(0,1);
154+
}
155+
156+
// each MPI rank gets storage for its piece of the fft
157+
BaseFab<GpuComplex<Real> > spectral_field(c_local_box, 1, The_Device_Arena());
158+
159+
// create real->complex fft objects with the appropriate backend and data about
160+
// the domain size and its local box size
161+
using fft_r2c_t =
162+
#ifdef AMREX_USE_CUDA
163+
heffte::fft2d_r2c<heffte::backend::cufft>;
164+
#elif AMREX_USE_HIP
165+
heffte::fft2d_r2c<heffte::backend::rocfft>;
166+
#else
167+
heffte::fft2d_r2c<heffte::backend::fftw>;
168+
#endif
169+
170+
// index bounds of the real-space box and of the spectral box
// NOTE(review): len and clen are not referenced in the code visible here.
auto lo = amrex::lbound(local_box);
171+
auto hi = amrex::ubound(local_box);
172+
auto len = amrex::length(local_box);
173+
auto clo = amrex::lbound(c_local_box);
174+
auto chi = amrex::ubound(c_local_box);
175+
auto clen = amrex::length(c_local_box);
176+
177+
// the heFFTe boxes describe a single z-plane (z index 0 to 0);
// the n_cell_z planes are passed as a batch to forward/backward below
auto fft = std::make_unique<fft_r2c_t>
178+
(heffte::box3d({ lo.x, lo.y, 0},
179+
{ hi.x, hi.y, 0}),
180+
heffte::box3d({clo.x, clo.y, 0},
181+
{chi.x, chi.y, 0}),
182+
0, ParallelDescriptor::Communicator());
183+
184+
// NOTE(review): start_step/end_step are only used by the commented-out timing print further down.
Real start_step = static_cast<Real>(ParallelDescriptor::second());
185+
using heffte_complex = typename heffte::fft_output<Real>::type;
186+
heffte_complex* spectral_data = (heffte_complex*) spectral_field.dataPtr();
187+
188+
// one 2D transform per z-plane; the workspace is sized for the whole batch
int batch_size = n_cell_z;
189+
Gpu::DeviceVector<heffte_complex> workspace(fft->size_workspace()*batch_size);
190+
191+
{ BL_PROFILE("HEFFTE-total");
192+
{
193+
BL_PROFILE("ForwardTransform");
194+
fft->forward(batch_size, rhs[local_boxid].dataPtr(), spectral_data,
195+
workspace.data());
196+
}
197+
198+
// Now, for every (i,j) Fourier mode of the x-y transformed rhs, solve a tridiagonal system in z
199+
Array4< GpuComplex<Real> > spectral = spectral_field.array();
200+
201+
// four work arrays on the spectral box: lower diagonal (ald), main diagonal (bd),
// upper diagonal (cud), and scratch storage for the elimination
FArrayBox tridiag_workspace(c_local_box,4);
202+
auto const& ald = tridiag_workspace.array(0);
203+
auto const& bd = tridiag_workspace.array(1);
204+
auto const& cud = tridiag_workspace.array(2);
205+
auto const& scratch = tridiag_workspace.array(3);
206+
207+
// cell size in z for each k; every entry is dx[2] here (uniform spacing)
Gpu::DeviceVector<Real> delzv(n_cell_z, dx[2]);
208+
auto const* delz = delzv.data();
209+
210+
// launch over a single x-y slab: each thread handles one (i,j) mode and
// loops over k itself (the lambda ignores its third index)
auto xybox = amrex::makeSlab(c_local_box, 2, 0);
211+
ParallelFor(xybox, [=] AMREX_GPU_DEVICE(int i, int j, int)
212+
{
213+
// wavenumbers of this mode in x and y
Real a = 2.*M_PI*i / L_x;
214+
Real b = 2.*M_PI*j / L_y;
215+
216+
// the values in the upper half of the spectral array in y are here interpreted as negative wavenumbers
217+
if (j >= n_cell_y/2) b = 2.*M_PI*(n_cell_y-j) / L_y;
218+
219+
// contribution of the second-order centered-difference x-y Laplacian for this mode
// (same stencil as the residual check at the end of main)
Real k2 = 2*(std::cos(a*dx[0])-1.)/(dx[0]*dx[0]) + 2*(std::cos(b*dx[1])-1.)/(dx[1]*dx[1]);
220+
221+
// Tridiagonal solve with homogeneous Neumann
// NOTE(review): the k==0 branch reads delz[k+1], so n_cell_z >= 2 is assumed.
222+
for( int k=0; k<n_cell_z; k++) {
223+
if(k==0) {
224+
ald(i,j,k) = 0.;
225+
cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
226+
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
227+
} else if (k == n_cell_z-1) {
228+
ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
229+
cud(i,j,k) = 0.;
230+
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
231+
// for the (0,0) mode k2 is zero, so every row sums to zero and the system is singular;
// doubling the last diagonal entry makes it solvable
// (presumably this pins the otherwise free additive constant - confirm)
if (i == 0 && j == 0) {
232+
bd(i,j,k) *= 2.0;
233+
}
234+
} else {
235+
ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
236+
cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
237+
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
238+
}
239+
}
240+
241+
// Thomas algorithm, forward sweep: scratch holds the modified upper-diagonal
// coefficients and spectral is overwritten with the modified right-hand side
scratch(i,j,0) = cud(i,j,0)/bd(i,j,0);
242+
spectral(i,j,0) = spectral(i,j,0)/bd(i,j,0);
243+
244+
for (int k = 1; k < n_cell_z; k++) {
245+
if (k < n_cell_z-1){
246+
scratch(i,j,k) = cud(i,j,k) / (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
247+
}
248+
spectral(i,j,k) = (spectral(i,j,k) - ald(i,j,k) * spectral(i,j,k - 1)) / (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1));
249+
}
250+
251+
// back substitution: spectral now holds the solution for this mode
for (int k = n_cell_z - 2; k >= 0; k--) {
252+
spectral(i,j,k) -= scratch(i,j,k) * spectral(i,j,k + 1);
253+
}
254+
});
255+
// wait for the device work above before the backward transform uses spectral_data
Gpu::streamSynchronize();
256+
257+
{
258+
BL_PROFILE("BackwardTransform");
259+
// heffte::scale::full requests a normalized inverse transform
// (presumably the full 1/(Nx*Ny) factor - see the heFFTe documentation)
fft->backward(batch_size, spectral_data, soln[local_boxid].dataPtr(),
260+
heffte::scale::full);
261+
}
262+
}
263+
264+
Real end_step = static_cast<Real>(ParallelDescriptor::second());
265+
// amrex::Print() << "TIME IN SOLVE " << end_step - start_step << std::endl;
266+
267+
// storage for variables to write to plotfile
268+
MultiFab plotfile(ba, dm, 2, 0);
269+
270+
// copy rhs and soln into plotfile
271+
MultiFab::Copy(plotfile, rhs , 0, 0, 1, 0);
272+
MultiFab::Copy(plotfile, soln, 0, 1, 1, 0);
273+
274+
// time and step are dummy variables required to WriteSingleLevelPlotfile
275+
Real time = 0.;
276+
int step = 0;
277+
278+
// arguments
279+
// 1: name of plotfile
280+
// 2: MultiFab containing data to plot
281+
// 3: variables names
282+
// 4: geometry object
283+
// 5: "time" of plotfile; not relevant in this example
284+
// 6: "time step" of plotfile; not relevant in this example
285+
WriteSingleLevelPlotfile("plt", plotfile, {"rhs", "soln"}, geom, time, step);
286+
287+
// Verification: res = rhs - Laplacian(soln), using the same centered stencil in x-y
// and one-sided (homogeneous Neumann) differences at the first and last z cells.
{
288+
// phi carries one ghost cell so the x-y stencil can reach periodic neighbors
MultiFab phi(soln.boxArray(), soln.DistributionMap(), 1, 1);
289+
MultiFab res(soln.boxArray(), soln.DistributionMap(), 1, 0);
290+
MultiFab::Copy(phi, soln, 0, 0, 1, 0);
291+
phi.FillBoundary(geom.periodicity());
292+
auto const& res_ma = res.arrays();
293+
auto const& phi_ma = phi.const_arrays();
294+
auto const& rhs_ma = rhs.const_arrays();
295+
ParallelFor(res, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
296+
{
297+
auto const& phia = phi_ma[b];
298+
auto lap = (phia(i-1,j,k)-2.*phia(i,j,k)+phia(i+1,j,k)) / (dx[0]*dx[0])
299+
+ (phia(i,j-1,k)-2.*phia(i,j,k)+phia(i,j+1,k)) / (dx[1]*dx[1]);
300+
// the z ghost cells are never read: the boundary rows drop the outside neighbor
if (k == 0) {
301+
lap += (-phia(i,j,k)+phia(i,j,k+1)) / (dx[2]*dx[2]);
302+
} else if (k == n_cell_z-1) {
303+
lap += (phia(i,j,k-1)-phia(i,j,k)) / (dx[2]*dx[2]);
304+
} else {
305+
lap += (phia(i,j,k-1)-2.*phia(i,j,k)+phia(i,j,k+1)) / (dx[2]*dx[2]);
306+
}
307+
res_ma[b](i,j,k) = rhs_ma[b](i,j,k) - lap;
308+
});
309+
amrex::Print() << " rhs.min & max: " << rhs.min(0) << " " << rhs.max(0) << "\n"
310+
<< " res.min & max: " << res.min(0) << " " << res.max(0) << "\n";
311+
}
312+
313+
} amrex::Finalize();
314+
}

0 commit comments

Comments
 (0)